Diffstat (limited to 'lib/Analysis')
41 files changed, 18190 insertions, 0 deletions
diff --git a/lib/Analysis/AliasAnalysis.cpp b/lib/Analysis/AliasAnalysis.cpp new file mode 100644 index 0000000..c5523ec --- /dev/null +++ b/lib/Analysis/AliasAnalysis.cpp @@ -0,0 +1,248 @@ +//===- AliasAnalysis.cpp - Generic Alias Analysis Interface Implementation -==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the generic AliasAnalysis interface which is used as the +// common interface used by all clients and implementations of alias analysis. +// +// This file also implements the default version of the AliasAnalysis interface +// that is to be used when no other implementation is specified. This does some +// simple tests that detect obvious cases: two different global pointers cannot +// alias, a global cannot alias a malloc, two different mallocs cannot alias, +// etc. +// +// This alias analysis implementation really isn't very good for anything, but +// it is very fast, and makes a nice clean default implementation. Because it +// handles lots of little corner cases, other, more complex, alias analysis +// implementations may choose to rely on this pass to resolve these simple and +// easy cases. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Pass.h" +#include "llvm/BasicBlock.h" +#include "llvm/Function.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Instructions.h" +#include "llvm/Type.h" +#include "llvm/Target/TargetData.h" +using namespace llvm; + +// Register the AliasAnalysis interface, providing a nice name to refer to. +static RegisterAnalysisGroup<AliasAnalysis> Z("Alias Analysis"); +char AliasAnalysis::ID = 0; + +//===----------------------------------------------------------------------===// +// Default chaining methods +//===----------------------------------------------------------------------===// + +AliasAnalysis::AliasResult +AliasAnalysis::alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size) { + assert(AA && "AA didn't call InitializeAliasAnalysis in its run method!"); + return AA->alias(V1, V1Size, V2, V2Size); +} + +void AliasAnalysis::getMustAliases(Value *P, std::vector<Value*> &RetVals) { + assert(AA && "AA didn't call InitializeAliasAnalysis in its run method!"); + return AA->getMustAliases(P, RetVals); +} + +bool AliasAnalysis::pointsToConstantMemory(const Value *P) { + assert(AA && "AA didn't call InitializeAliasAnalysis in its run method!"); + return AA->pointsToConstantMemory(P); +} + +bool AliasAnalysis::hasNoModRefInfoForCalls() const { + assert(AA && "AA didn't call InitializeAliasAnalysis in its run method!"); + return AA->hasNoModRefInfoForCalls(); +} + +void AliasAnalysis::deleteValue(Value *V) { + assert(AA && "AA didn't call InitializeAliasAnalysis in its run method!"); + AA->deleteValue(V); +} + +void AliasAnalysis::copyValue(Value *From, Value *To) { + assert(AA && "AA didn't call InitializeAliasAnalysis in its run method!"); + AA->copyValue(From, To); +} + +AliasAnalysis::ModRefResult +AliasAnalysis::getModRefInfo(CallSite CS1, CallSite CS2) { + // FIXME: we can do better. 
+ assert(AA && "AA didn't call InitializeAliasAnalysis in its run method!"); + return AA->getModRefInfo(CS1, CS2); +} + + +//===----------------------------------------------------------------------===// +// AliasAnalysis non-virtual helper method implementation +//===----------------------------------------------------------------------===// + +AliasAnalysis::ModRefResult +AliasAnalysis::getModRefInfo(LoadInst *L, Value *P, unsigned Size) { + return alias(L->getOperand(0), TD->getTypeStoreSize(L->getType()), + P, Size) ? Ref : NoModRef; +} + +AliasAnalysis::ModRefResult +AliasAnalysis::getModRefInfo(StoreInst *S, Value *P, unsigned Size) { + // If the stored address cannot alias the pointer in question, then the + // pointer cannot be modified by the store. + if (!alias(S->getOperand(1), + TD->getTypeStoreSize(S->getOperand(0)->getType()), P, Size)) + return NoModRef; + + // If the pointer is a pointer to constant memory, then it could not have been + // modified by this store. + return pointsToConstantMemory(P) ? NoModRef : Mod; +} + +AliasAnalysis::ModRefBehavior +AliasAnalysis::getModRefBehavior(CallSite CS, + std::vector<PointerAccessInfo> *Info) { + if (CS.doesNotAccessMemory()) + // Can't do better than this. + return DoesNotAccessMemory; + ModRefBehavior MRB = getModRefBehavior(CS.getCalledFunction(), Info); + if (MRB != DoesNotAccessMemory && CS.onlyReadsMemory()) + return OnlyReadsMemory; + return MRB; +} + +AliasAnalysis::ModRefBehavior +AliasAnalysis::getModRefBehavior(Function *F, + std::vector<PointerAccessInfo> *Info) { + if (F) { + if (F->doesNotAccessMemory()) + // Can't do better than this. + return DoesNotAccessMemory; + if (F->onlyReadsMemory()) + return OnlyReadsMemory; + if (unsigned id = F->getIntrinsicID()) { +#define GET_INTRINSIC_MODREF_BEHAVIOR +#include "llvm/Intrinsics.gen" +#undef GET_INTRINSIC_MODREF_BEHAVIOR + } + } + return UnknownModRefBehavior; +} + +AliasAnalysis::ModRefResult +AliasAnalysis::getModRefInfo(CallSite CS, Value *P, unsigned Size) { + ModRefResult Mask = ModRef; + ModRefBehavior MRB = getModRefBehavior(CS); + if (MRB == DoesNotAccessMemory) + return NoModRef; + else if (MRB == OnlyReadsMemory) + Mask = Ref; + else if (MRB == AliasAnalysis::AccessesArguments) { + bool doesAlias = false; + for (CallSite::arg_iterator AI = CS.arg_begin(), AE = CS.arg_end(); + AI != AE; ++AI) + if (alias(*AI, ~0U, P, Size) != NoAlias) { + doesAlias = true; + break; + } + + if (!doesAlias) + return NoModRef; + } + + if (!AA) return Mask; + + // If P points to a constant memory location, the call definitely could not + // modify the memory location. + if ((Mask & Mod) && AA->pointsToConstantMemory(P)) + Mask = ModRefResult(Mask & ~Mod); + + return ModRefResult(Mask & AA->getModRefInfo(CS, P, Size)); +} + +// AliasAnalysis destructor: DO NOT move this to the header file for +// AliasAnalysis or else clients of the AliasAnalysis class may not depend on +// the AliasAnalysis.o file in the current .a file, causing alias analysis +// support to not be included in the tool correctly! +// +AliasAnalysis::~AliasAnalysis() {} + +/// InitializeAliasAnalysis - Subclasses must call this method to initialize the +/// AliasAnalysis interface before any other methods are called. 
+/// +void AliasAnalysis::InitializeAliasAnalysis(Pass *P) { + TD = &P->getAnalysis<TargetData>(); + AA = &P->getAnalysis<AliasAnalysis>(); +} + +// getAnalysisUsage - All alias analysis implementations should invoke this +// directly (using AliasAnalysis::getAnalysisUsage(AU)) to make sure that +// TargetData is required by the pass. +void AliasAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<TargetData>(); // All AA's need TargetData. + AU.addRequired<AliasAnalysis>(); // All AA's chain +} + +/// canBasicBlockModify - Return true if it is possible for execution of the +/// specified basic block to modify the value pointed to by Ptr. +/// +bool AliasAnalysis::canBasicBlockModify(const BasicBlock &BB, + const Value *Ptr, unsigned Size) { + return canInstructionRangeModify(BB.front(), BB.back(), Ptr, Size); +} + +/// canInstructionRangeModify - Return true if it is possible for the execution +/// of the specified instructions to modify the value pointed to by Ptr. The +/// instructions to consider are all of the instructions in the range of [I1,I2] +/// INCLUSIVE. I1 and I2 must be in the same basic block. +/// +bool AliasAnalysis::canInstructionRangeModify(const Instruction &I1, + const Instruction &I2, + const Value *Ptr, unsigned Size) { + assert(I1.getParent() == I2.getParent() && + "Instructions not in same basic block!"); + BasicBlock::iterator I = const_cast<Instruction*>(&I1); + BasicBlock::iterator E = const_cast<Instruction*>(&I2); + ++E; // Convert from inclusive to exclusive range. + + for (; I != E; ++I) // Check every instruction in range + if (getModRefInfo(I, const_cast<Value*>(Ptr), Size) & Mod) + return true; + return false; +} + +/// isNoAliasCall - Return true if this pointer is returned by a noalias +/// function. +bool llvm::isNoAliasCall(const Value *V) { + if (isa<CallInst>(V) || isa<InvokeInst>(V)) + return CallSite(const_cast<Instruction*>(cast<Instruction>(V))) + .paramHasAttr(0, Attribute::NoAlias); + return false; +} + +/// isIdentifiedObject - Return true if this pointer refers to a distinct and +/// identifiable object. This returns true for: +/// Global Variables and Functions +/// Allocas and Mallocs +/// ByVal and NoAlias Arguments +/// NoAlias returns +/// +bool llvm::isIdentifiedObject(const Value *V) { + if (isa<GlobalValue>(V) || isa<AllocationInst>(V) || isNoAliasCall(V)) + return true; + if (const Argument *A = dyn_cast<Argument>(V)) + return A->hasNoAliasAttr() || A->hasByValAttr(); + return false; +} + +// Because of the way .a files work, we must force the BasicAA implementation to +// be pulled in if the AliasAnalysis classes are pulled in. Otherwise we run +// the risk of AliasAnalysis being used, but the default implementation not +// being linked into the tool that uses it. +DEFINING_FILE_FOR(AliasAnalysis) diff --git a/lib/Analysis/AliasAnalysisCounter.cpp b/lib/Analysis/AliasAnalysisCounter.cpp new file mode 100644 index 0000000..4362d7d --- /dev/null +++ b/lib/Analysis/AliasAnalysisCounter.cpp @@ -0,0 +1,173 @@ +//===- AliasAnalysisCounter.cpp - Alias Analysis Query Counter ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass which can be used to count how many alias queries +// are being made and how the alias analysis implementation being used responds. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Passes.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Streams.h" +using namespace llvm; + +static cl::opt<bool> +PrintAll("count-aa-print-all-queries", cl::ReallyHidden); +static cl::opt<bool> +PrintAllFailures("count-aa-print-all-failed-queries", cl::ReallyHidden); + +namespace { + class VISIBILITY_HIDDEN AliasAnalysisCounter + : public ModulePass, public AliasAnalysis { + unsigned No, May, Must; + unsigned NoMR, JustRef, JustMod, MR; + const char *Name; + Module *M; + public: + static char ID; // Class identification, replacement for typeinfo + AliasAnalysisCounter() : ModulePass(&ID) { + No = May = Must = 0; + NoMR = JustRef = JustMod = MR = 0; + } + + void printLine(const char *Desc, unsigned Val, unsigned Sum) { + cerr << " " << Val << " " << Desc << " responses (" + << Val*100/Sum << "%)\n"; + } + ~AliasAnalysisCounter() { + unsigned AASum = No+May+Must; + unsigned MRSum = NoMR+JustRef+JustMod+MR; + if (AASum + MRSum) { // Print a report if any counted queries occurred... + cerr << "\n===== Alias Analysis Counter Report =====\n" + << " Analysis counted: " << Name << "\n" + << " " << AASum << " Total Alias Queries Performed\n"; + if (AASum) { + printLine("no alias", No, AASum); + printLine("may alias", May, AASum); + printLine("must alias", Must, AASum); + cerr << " Alias Analysis Counter Summary: " << No*100/AASum << "%/" + << May*100/AASum << "%/" << Must*100/AASum<<"%\n\n"; + } + + cerr << " " << MRSum << " Total Mod/Ref Queries Performed\n"; + if (MRSum) { + printLine("no mod/ref", NoMR, MRSum); + printLine("ref", JustRef, MRSum); + printLine("mod", JustMod, MRSum); + printLine("mod/ref", MR, MRSum); + cerr << " Mod/Ref Analysis Counter Summary: " <<NoMR*100/MRSum<< "%/" + << JustRef*100/MRSum << "%/" << JustMod*100/MRSum << "%/" + << MR*100/MRSum <<"%\n\n"; + } + } + } + + bool runOnModule(Module &M) { + this->M = &M; + InitializeAliasAnalysis(this); + Name = dynamic_cast<Pass*>(&getAnalysis<AliasAnalysis>())->getPassName(); + return false; + } + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AliasAnalysis::getAnalysisUsage(AU); + AU.addRequired<AliasAnalysis>(); + AU.setPreservesAll(); + } + + // FIXME: We could count these too... + bool pointsToConstantMemory(const Value *P) { + return getAnalysis<AliasAnalysis>().pointsToConstantMemory(P); + } + bool doesNotAccessMemory(CallSite CS) { + return getAnalysis<AliasAnalysis>().doesNotAccessMemory(CS); + } + bool doesNotAccessMemory(Function *F) { + return getAnalysis<AliasAnalysis>().doesNotAccessMemory(F); + } + bool onlyReadsMemory(CallSite CS) { + return getAnalysis<AliasAnalysis>().onlyReadsMemory(CS); + } + bool onlyReadsMemory(Function *F) { + return getAnalysis<AliasAnalysis>().onlyReadsMemory(F); + } + + + // Forwarding functions: just delegate to a real AA implementation, counting + // the number of responses... 
+ AliasResult alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size); + + ModRefResult getModRefInfo(CallSite CS, Value *P, unsigned Size); + ModRefResult getModRefInfo(CallSite CS1, CallSite CS2) { + return AliasAnalysis::getModRefInfo(CS1,CS2); + } + }; +} + +char AliasAnalysisCounter::ID = 0; +static RegisterPass<AliasAnalysisCounter> +X("count-aa", "Count Alias Analysis Query Responses", false, true); +static RegisterAnalysisGroup<AliasAnalysis> Y(X); + +ModulePass *llvm::createAliasAnalysisCounterPass() { + return new AliasAnalysisCounter(); +} + +AliasAnalysis::AliasResult +AliasAnalysisCounter::alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size) { + AliasResult R = getAnalysis<AliasAnalysis>().alias(V1, V1Size, V2, V2Size); + + const char *AliasString; + switch (R) { + default: assert(0 && "Unknown alias type!"); + case NoAlias: No++; AliasString = "No alias"; break; + case MayAlias: May++; AliasString = "May alias"; break; + case MustAlias: Must++; AliasString = "Must alias"; break; + } + + if (PrintAll || (PrintAllFailures && R == MayAlias)) { + cerr << AliasString << ":\t"; + cerr << "[" << V1Size << "B] "; + WriteAsOperand(*cerr.stream(), V1, true, M); + cerr << ", "; + cerr << "[" << V2Size << "B] "; + WriteAsOperand(*cerr.stream(), V2, true, M); + cerr << "\n"; + } + + return R; +} + +AliasAnalysis::ModRefResult +AliasAnalysisCounter::getModRefInfo(CallSite CS, Value *P, unsigned Size) { + ModRefResult R = getAnalysis<AliasAnalysis>().getModRefInfo(CS, P, Size); + + const char *MRString; + switch (R) { + default: assert(0 && "Unknown mod/ref type!"); + case NoModRef: NoMR++; MRString = "NoModRef"; break; + case Ref: JustRef++; MRString = "JustRef"; break; + case Mod: JustMod++; MRString = "JustMod"; break; + case ModRef: MR++; MRString = "ModRef"; break; + } + + if (PrintAll || (PrintAllFailures && R == ModRef)) { + cerr << MRString << ": Ptr: "; + cerr << "[" << Size << "B] "; + WriteAsOperand(*cerr.stream(), P, true, M); + cerr << "\t<->" << *CS.getInstruction(); + } + return R; +} diff --git a/lib/Analysis/AliasAnalysisEvaluator.cpp b/lib/Analysis/AliasAnalysisEvaluator.cpp new file mode 100644 index 0000000..07820e3 --- /dev/null +++ b/lib/Analysis/AliasAnalysisEvaluator.cpp @@ -0,0 +1,246 @@ +//===- AliasAnalysisEvaluator.cpp - Alias Analysis Accuracy Evaluator -----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple N^2 alias analysis accuracy evaluator. +// Basically, for each function in the program, it simply queries to see how the +// alias analysis implementation answers alias queries between each pair of +// pointers in the function. +// +// This is inspired and adapted from code by: Naveen Neelakantam, Francesco +// Spadini, and Wojciech Stryjewski. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Streams.h" +#include <set> +#include <sstream> +using namespace llvm; + +static cl::opt<bool> PrintAll("print-all-alias-modref-info", cl::ReallyHidden); + +static cl::opt<bool> PrintNoAlias("print-no-aliases", cl::ReallyHidden); +static cl::opt<bool> PrintMayAlias("print-may-aliases", cl::ReallyHidden); +static cl::opt<bool> PrintMustAlias("print-must-aliases", cl::ReallyHidden); + +static cl::opt<bool> PrintNoModRef("print-no-modref", cl::ReallyHidden); +static cl::opt<bool> PrintMod("print-mod", cl::ReallyHidden); +static cl::opt<bool> PrintRef("print-ref", cl::ReallyHidden); +static cl::opt<bool> PrintModRef("print-modref", cl::ReallyHidden); + +namespace { + class VISIBILITY_HIDDEN AAEval : public FunctionPass { + unsigned NoAlias, MayAlias, MustAlias; + unsigned NoModRef, Mod, Ref, ModRef; + + public: + static char ID; // Pass identification, replacement for typeid + AAEval() : FunctionPass(&ID) {} + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<AliasAnalysis>(); + AU.setPreservesAll(); + } + + bool doInitialization(Module &M) { + NoAlias = MayAlias = MustAlias = 0; + NoModRef = Mod = Ref = ModRef = 0; + + if (PrintAll) { + PrintNoAlias = PrintMayAlias = PrintMustAlias = true; + PrintNoModRef = PrintMod = PrintRef = PrintModRef = true; + } + return false; + } + + bool runOnFunction(Function &F); + bool doFinalization(Module &M); + }; +} + +char AAEval::ID = 0; +static RegisterPass<AAEval> +X("aa-eval", "Exhaustive Alias Analysis Precision Evaluator", false, true); + +FunctionPass *llvm::createAAEvalPass() { return new AAEval(); } + +static void PrintResults(const char *Msg, bool P, const Value *V1, const Value *V2, + const Module *M) { + if (P) { + std::stringstream s1, s2; + WriteAsOperand(s1, V1, true, M); + WriteAsOperand(s2, V2, true, M); + std::string o1(s1.str()), o2(s2.str()); + if (o2 < o1) + std::swap(o1, o2); + cerr << " " << Msg << ":\t" + << o1 << ", " + << o2 << "\n"; + } +} + +static inline void +PrintModRefResults(const char *Msg, bool P, Instruction *I, Value *Ptr, + Module *M) { + if (P) { + cerr << " " << Msg << ": Ptr: "; + WriteAsOperand(*cerr.stream(), Ptr, true, M); + cerr << "\t<->" << *I; + } +} + +bool AAEval::runOnFunction(Function &F) { + AliasAnalysis &AA = getAnalysis<AliasAnalysis>(); + + const TargetData &TD = AA.getTargetData(); + + std::set<Value *> Pointers; + std::set<CallSite> CallSites; + + for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) + if (isa<PointerType>(I->getType())) // Add all pointer arguments + Pointers.insert(I); + + for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) { + if (isa<PointerType>(I->getType())) // Add all pointer instructions + Pointers.insert(&*I); + Instruction &Inst = *I; + User::op_iterator OI = Inst.op_begin(); + CallSite CS = CallSite::get(&Inst); + if (CS.getInstruction() && + isa<Function>(CS.getCalledValue())) + ++OI; // Skip actual functions for direct function calls. 
+ for (; OI != Inst.op_end(); ++OI) + if (isa<PointerType>((*OI)->getType()) && !isa<ConstantPointerNull>(*OI)) + Pointers.insert(*OI); + + if (CS.getInstruction()) CallSites.insert(CS); + } + + if (PrintNoAlias || PrintMayAlias || PrintMustAlias || + PrintNoModRef || PrintMod || PrintRef || PrintModRef) + cerr << "Function: " << F.getName() << ": " << Pointers.size() + << " pointers, " << CallSites.size() << " call sites\n"; + + // iterate over the worklist, and run the full (n^2)/2 disambiguations + for (std::set<Value *>::iterator I1 = Pointers.begin(), E = Pointers.end(); + I1 != E; ++I1) { + unsigned I1Size = 0; + const Type *I1ElTy = cast<PointerType>((*I1)->getType())->getElementType(); + if (I1ElTy->isSized()) I1Size = TD.getTypeStoreSize(I1ElTy); + + for (std::set<Value *>::iterator I2 = Pointers.begin(); I2 != I1; ++I2) { + unsigned I2Size = 0; + const Type *I2ElTy =cast<PointerType>((*I2)->getType())->getElementType(); + if (I2ElTy->isSized()) I2Size = TD.getTypeStoreSize(I2ElTy); + + switch (AA.alias(*I1, I1Size, *I2, I2Size)) { + case AliasAnalysis::NoAlias: + PrintResults("NoAlias", PrintNoAlias, *I1, *I2, F.getParent()); + ++NoAlias; break; + case AliasAnalysis::MayAlias: + PrintResults("MayAlias", PrintMayAlias, *I1, *I2, F.getParent()); + ++MayAlias; break; + case AliasAnalysis::MustAlias: + PrintResults("MustAlias", PrintMustAlias, *I1, *I2, F.getParent()); + ++MustAlias; break; + default: + cerr << "Unknown alias query result!\n"; + } + } + } + + // Mod/ref alias analysis: compare all pairs of calls and values + for (std::set<CallSite>::iterator C = CallSites.begin(), + Ce = CallSites.end(); C != Ce; ++C) { + Instruction *I = C->getInstruction(); + + for (std::set<Value *>::iterator V = Pointers.begin(), Ve = Pointers.end(); + V != Ve; ++V) { + unsigned Size = 0; + const Type *ElTy = cast<PointerType>((*V)->getType())->getElementType(); + if (ElTy->isSized()) Size = TD.getTypeStoreSize(ElTy); + + switch (AA.getModRefInfo(*C, *V, Size)) { + case AliasAnalysis::NoModRef: + PrintModRefResults("NoModRef", PrintNoModRef, I, *V, F.getParent()); + ++NoModRef; break; + case AliasAnalysis::Mod: + PrintModRefResults(" Mod", PrintMod, I, *V, F.getParent()); + ++Mod; break; + case AliasAnalysis::Ref: + PrintModRefResults(" Ref", PrintRef, I, *V, F.getParent()); + ++Ref; break; + case AliasAnalysis::ModRef: + PrintModRefResults(" ModRef", PrintModRef, I, *V, F.getParent()); + ++ModRef; break; + default: + cerr << "Unknown alias query result!\n"; + } + } + } + + return false; +} + +static void PrintPercent(unsigned Num, unsigned Sum) { + cerr << "(" << Num*100ULL/Sum << "." 
+ << ((Num*1000ULL/Sum) % 10) << "%)\n"; +} + +bool AAEval::doFinalization(Module &M) { + unsigned AliasSum = NoAlias + MayAlias + MustAlias; + cerr << "===== Alias Analysis Evaluator Report =====\n"; + if (AliasSum == 0) { + cerr << " Alias Analysis Evaluator Summary: No pointers!\n"; + } else { + cerr << " " << AliasSum << " Total Alias Queries Performed\n"; + cerr << " " << NoAlias << " no alias responses "; + PrintPercent(NoAlias, AliasSum); + cerr << " " << MayAlias << " may alias responses "; + PrintPercent(MayAlias, AliasSum); + cerr << " " << MustAlias << " must alias responses "; + PrintPercent(MustAlias, AliasSum); + cerr << " Alias Analysis Evaluator Pointer Alias Summary: " + << NoAlias*100/AliasSum << "%/" << MayAlias*100/AliasSum << "%/" + << MustAlias*100/AliasSum << "%\n"; + } + + // Display the summary for mod/ref analysis + unsigned ModRefSum = NoModRef + Mod + Ref + ModRef; + if (ModRefSum == 0) { + cerr << " Alias Analysis Mod/Ref Evaluator Summary: no mod/ref!\n"; + } else { + cerr << " " << ModRefSum << " Total ModRef Queries Performed\n"; + cerr << " " << NoModRef << " no mod/ref responses "; + PrintPercent(NoModRef, ModRefSum); + cerr << " " << Mod << " mod responses "; + PrintPercent(Mod, ModRefSum); + cerr << " " << Ref << " ref responses "; + PrintPercent(Ref, ModRefSum); + cerr << " " << ModRef << " mod & ref responses "; + PrintPercent(ModRef, ModRefSum); + cerr << " Alias Analysis Evaluator Mod/Ref Summary: " + << NoModRef*100/ModRefSum << "%/" << Mod*100/ModRefSum << "%/" + << Ref*100/ModRefSum << "%/" << ModRef*100/ModRefSum << "%\n"; + } + + return false; +} diff --git a/lib/Analysis/AliasDebugger.cpp b/lib/Analysis/AliasDebugger.cpp new file mode 100644 index 0000000..1e82621 --- /dev/null +++ b/lib/Analysis/AliasDebugger.cpp @@ -0,0 +1,123 @@ +//===- AliasDebugger.cpp - Simple Alias Analysis Use Checker --------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This simple pass checks alias analysis users to ensure that if they +// create a new value, they do not query AA without informing it of the value. +// It acts as a shim over any other AA pass you want. +// +// Yes keeping track of every value in the program is expensive, but this is +// a debugging pass. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Passes.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Instructions.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Support/Compiler.h" +#include <set> +using namespace llvm; + +namespace { + + class VISIBILITY_HIDDEN AliasDebugger + : public ModulePass, public AliasAnalysis { + + //What we do is simple. Keep track of every value the AA could + //know about, and verify that queries are one of those. 
+ //A query to a value that didn't exist when the AA was created + //means someone forgot to update the AA when creating new values + + std::set<const Value*> Vals; + + public: + static char ID; // Class identification, replacement for typeinfo + AliasDebugger() : ModulePass(&ID) {} + + bool runOnModule(Module &M) { + InitializeAliasAnalysis(this); // set up super class + + for(Module::global_iterator I = M.global_begin(), + E = M.global_end(); I != E; ++I) + Vals.insert(&*I); + + for(Module::iterator I = M.begin(), + E = M.end(); I != E; ++I){ + Vals.insert(&*I); + if(!I->isDeclaration()) { + for (Function::arg_iterator AI = I->arg_begin(), AE = I->arg_end(); + AI != AE; ++AI) + Vals.insert(&*AI); + for (Function::const_iterator FI = I->begin(), FE = I->end(); + FI != FE; ++FI) + for (BasicBlock::const_iterator BI = FI->begin(), BE = FI->end(); + BI != BE; ++BI) + Vals.insert(&*BI); + } + + } + return false; + } + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AliasAnalysis::getAnalysisUsage(AU); + AU.setPreservesAll(); // Does not transform code + } + + //------------------------------------------------ + // Implement the AliasAnalysis API + // + AliasResult alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size) { + assert(Vals.find(V1) != Vals.end() && "Never seen value in AA before"); + assert(Vals.find(V2) != Vals.end() && "Never seen value in AA before"); + return AliasAnalysis::alias(V1, V1Size, V2, V2Size); + } + + ModRefResult getModRefInfo(CallSite CS, Value *P, unsigned Size) { + assert(Vals.find(P) != Vals.end() && "Never seen value in AA before"); + return AliasAnalysis::getModRefInfo(CS, P, Size); + } + + ModRefResult getModRefInfo(CallSite CS1, CallSite CS2) { + return AliasAnalysis::getModRefInfo(CS1,CS2); + } + + void getMustAliases(Value *P, std::vector<Value*> &RetVals) { + assert(Vals.find(P) != Vals.end() && "Never seen value in AA before"); + return AliasAnalysis::getMustAliases(P, RetVals); + } + + bool pointsToConstantMemory(const Value *P) { + assert(Vals.find(P) != Vals.end() && "Never seen value in AA before"); + return AliasAnalysis::pointsToConstantMemory(P); + } + + virtual void deleteValue(Value *V) { + assert(Vals.find(V) != Vals.end() && "Never seen value in AA before"); + AliasAnalysis::deleteValue(V); + } + virtual void copyValue(Value *From, Value *To) { + Vals.insert(To); + AliasAnalysis::copyValue(From, To); + } + + }; +} + +char AliasDebugger::ID = 0; +static RegisterPass<AliasDebugger> +X("debug-aa", "AA use debugger", false, true); +static RegisterAnalysisGroup<AliasAnalysis> Y(X); + +Pass *llvm::createAliasDebugger() { return new AliasDebugger(); } + diff --git a/lib/Analysis/AliasSetTracker.cpp b/lib/Analysis/AliasSetTracker.cpp new file mode 100644 index 0000000..18c2b665 --- /dev/null +++ b/lib/Analysis/AliasSetTracker.cpp @@ -0,0 +1,608 @@ +//===- AliasSetTracker.cpp - Alias Sets Tracker implementation-------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the AliasSetTracker and AliasSet classes. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/Type.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/Support/Streams.h" +using namespace llvm; + +/// mergeSetIn - Merge the specified alias set into this alias set. +/// +void AliasSet::mergeSetIn(AliasSet &AS, AliasSetTracker &AST) { + assert(!AS.Forward && "Alias set is already forwarding!"); + assert(!Forward && "This set is a forwarding set!!"); + + // Update the alias and access types of this set... + AccessTy |= AS.AccessTy; + AliasTy |= AS.AliasTy; + + if (AliasTy == MustAlias) { + // Check that these two merged sets really are must aliases. Since both + // used to be must-alias sets, we can just check any pointer from each set + // for aliasing. + AliasAnalysis &AA = AST.getAliasAnalysis(); + PointerRec *L = getSomePointer(); + PointerRec *R = AS.getSomePointer(); + + // If the pointers are not a must-alias pair, this set becomes a may alias. + if (AA.alias(L->getValue(), L->getSize(), R->getValue(), R->getSize()) + != AliasAnalysis::MustAlias) + AliasTy = MayAlias; + } + + if (CallSites.empty()) { // Merge call sites... + if (!AS.CallSites.empty()) + std::swap(CallSites, AS.CallSites); + } else if (!AS.CallSites.empty()) { + CallSites.insert(CallSites.end(), AS.CallSites.begin(), AS.CallSites.end()); + AS.CallSites.clear(); + } + + AS.Forward = this; // Forward across AS now... + addRef(); // AS is now pointing to us... + + // Merge the list of constituent pointers... + if (AS.PtrList) { + *PtrListEnd = AS.PtrList; + AS.PtrList->setPrevInList(PtrListEnd); + PtrListEnd = AS.PtrListEnd; + + AS.PtrList = 0; + AS.PtrListEnd = &AS.PtrList; + assert(*AS.PtrListEnd == 0 && "End of list is not null?"); + } +} + +void AliasSetTracker::removeAliasSet(AliasSet *AS) { + if (AliasSet *Fwd = AS->Forward) { + Fwd->dropRef(*this); + AS->Forward = 0; + } + AliasSets.erase(AS); +} + +void AliasSet::removeFromTracker(AliasSetTracker &AST) { + assert(RefCount == 0 && "Cannot remove non-dead alias set from tracker!"); + AST.removeAliasSet(this); +} + +void AliasSet::addPointer(AliasSetTracker &AST, PointerRec &Entry, + unsigned Size, bool KnownMustAlias) { + assert(!Entry.hasAliasSet() && "Entry already in set!"); + + // Check to see if we have to downgrade to _may_ alias. + if (isMustAlias() && !KnownMustAlias) + if (PointerRec *P = getSomePointer()) { + AliasAnalysis &AA = AST.getAliasAnalysis(); + AliasAnalysis::AliasResult Result = + AA.alias(P->getValue(), P->getSize(), Entry.getValue(), Size); + if (Result == AliasAnalysis::MayAlias) + AliasTy = MayAlias; + else // First entry of must alias must have maximum size! + P->updateSize(Size); + assert(Result != AliasAnalysis::NoAlias && "Cannot be part of must set!"); + } + + Entry.setAliasSet(this); + Entry.updateSize(Size); + + // Add it to the end of the list... + assert(*PtrListEnd == 0 && "End of list is not null?"); + *PtrListEnd = &Entry; + PtrListEnd = Entry.setPrevInList(PtrListEnd); + assert(*PtrListEnd == 0 && "End of list is not null?"); + addRef(); // Entry points to alias set... 
+} + +void AliasSet::addCallSite(CallSite CS, AliasAnalysis &AA) { + CallSites.push_back(CS); + + AliasAnalysis::ModRefBehavior Behavior = AA.getModRefBehavior(CS); + if (Behavior == AliasAnalysis::DoesNotAccessMemory) + return; + else if (Behavior == AliasAnalysis::OnlyReadsMemory) { + AliasTy = MayAlias; + AccessTy |= Refs; + return; + } + + // FIXME: This should use mod/ref information to make this not suck so bad + AliasTy = MayAlias; + AccessTy = ModRef; +} + +/// aliasesPointer - Return true if the specified pointer "may" (or must) +/// alias one of the members in the set. +/// +bool AliasSet::aliasesPointer(const Value *Ptr, unsigned Size, + AliasAnalysis &AA) const { + if (AliasTy == MustAlias) { + assert(CallSites.empty() && "Illegal must alias set!"); + + // If this is a set of MustAliases, only check to see if the pointer aliases + // SOME value in the set... + PointerRec *SomePtr = getSomePointer(); + assert(SomePtr && "Empty must-alias set??"); + return AA.alias(SomePtr->getValue(), SomePtr->getSize(), Ptr, Size); + } + + // If this is a may-alias set, we have to check all of the pointers in the set + // to be sure it doesn't alias the set... + for (iterator I = begin(), E = end(); I != E; ++I) + if (AA.alias(Ptr, Size, I.getPointer(), I.getSize())) + return true; + + // Check the call sites list and invoke list... + if (!CallSites.empty()) { + if (AA.hasNoModRefInfoForCalls()) + return true; + + for (unsigned i = 0, e = CallSites.size(); i != e; ++i) + if (AA.getModRefInfo(CallSites[i], const_cast<Value*>(Ptr), Size) + != AliasAnalysis::NoModRef) + return true; + } + + return false; +} + +bool AliasSet::aliasesCallSite(CallSite CS, AliasAnalysis &AA) const { + if (AA.doesNotAccessMemory(CS)) + return false; + + if (AA.hasNoModRefInfoForCalls()) + return true; + + for (unsigned i = 0, e = CallSites.size(); i != e; ++i) + if (AA.getModRefInfo(CallSites[i], CS) != AliasAnalysis::NoModRef || + AA.getModRefInfo(CS, CallSites[i]) != AliasAnalysis::NoModRef) + return true; + + for (iterator I = begin(), E = end(); I != E; ++I) + if (AA.getModRefInfo(CS, I.getPointer(), I.getSize()) != + AliasAnalysis::NoModRef) + return true; + + return false; +} + +void AliasSetTracker::clear() { + // Delete all the PointerRec entries. + for (DenseMap<Value*, AliasSet::PointerRec*>::iterator I = PointerMap.begin(), + E = PointerMap.end(); I != E; ++I) + I->second->eraseFromList(); + + PointerMap.clear(); + + // The alias sets should all be clear now. + AliasSets.clear(); +} + + +/// findAliasSetForPointer - Given a pointer, find the one alias set to put the +/// instruction referring to the pointer into. If there are multiple alias sets +/// that may alias the pointer, merge them together and return the unified set. +/// +AliasSet *AliasSetTracker::findAliasSetForPointer(const Value *Ptr, + unsigned Size) { + AliasSet *FoundSet = 0; + for (iterator I = begin(), E = end(); I != E; ++I) + if (!I->Forward && I->aliasesPointer(Ptr, Size, AA)) { + if (FoundSet == 0) { // If this is the first alias set ptr can go into. + FoundSet = I; // Remember it. + } else { // Otherwise, we must merge the sets. + FoundSet->mergeSetIn(*I, *this); // Merge in contents. + } + } + + return FoundSet; +} + +/// containsPointer - Return true if the specified location is represented by +/// this alias set, false otherwise. This does not modify the AST object or +/// alias sets. 
+bool AliasSetTracker::containsPointer(Value *Ptr, unsigned Size) const { + for (const_iterator I = begin(), E = end(); I != E; ++I) + if (!I->Forward && I->aliasesPointer(Ptr, Size, AA)) + return true; + return false; +} + + + +AliasSet *AliasSetTracker::findAliasSetForCallSite(CallSite CS) { + AliasSet *FoundSet = 0; + for (iterator I = begin(), E = end(); I != E; ++I) + if (!I->Forward && I->aliasesCallSite(CS, AA)) { + if (FoundSet == 0) { // If this is the first alias set ptr can go into. + FoundSet = I; // Remember it. + } else if (!I->Forward) { // Otherwise, we must merge the sets. + FoundSet->mergeSetIn(*I, *this); // Merge in contents. + } + } + + return FoundSet; +} + + + + +/// getAliasSetForPointer - Return the alias set that the specified pointer +/// lives in. +AliasSet &AliasSetTracker::getAliasSetForPointer(Value *Pointer, unsigned Size, + bool *New) { + AliasSet::PointerRec &Entry = getEntryFor(Pointer); + + // Check to see if the pointer is already known... + if (Entry.hasAliasSet()) { + Entry.updateSize(Size); + // Return the set! + return *Entry.getAliasSet(*this)->getForwardedTarget(*this); + } else if (AliasSet *AS = findAliasSetForPointer(Pointer, Size)) { + // Add it to the alias set it aliases... + AS->addPointer(*this, Entry, Size); + return *AS; + } else { + if (New) *New = true; + // Otherwise create a new alias set to hold the loaded pointer... + AliasSets.push_back(new AliasSet()); + AliasSets.back().addPointer(*this, Entry, Size); + return AliasSets.back(); + } +} + +bool AliasSetTracker::add(Value *Ptr, unsigned Size) { + bool NewPtr; + addPointer(Ptr, Size, AliasSet::NoModRef, NewPtr); + return NewPtr; +} + + +bool AliasSetTracker::add(LoadInst *LI) { + bool NewPtr; + AliasSet &AS = addPointer(LI->getOperand(0), + AA.getTargetData().getTypeStoreSize(LI->getType()), + AliasSet::Refs, NewPtr); + if (LI->isVolatile()) AS.setVolatile(); + return NewPtr; +} + +bool AliasSetTracker::add(StoreInst *SI) { + bool NewPtr; + Value *Val = SI->getOperand(0); + AliasSet &AS = addPointer(SI->getOperand(1), + AA.getTargetData().getTypeStoreSize(Val->getType()), + AliasSet::Mods, NewPtr); + if (SI->isVolatile()) AS.setVolatile(); + return NewPtr; +} + +bool AliasSetTracker::add(FreeInst *FI) { + bool NewPtr; + addPointer(FI->getOperand(0), ~0, AliasSet::Mods, NewPtr); + return NewPtr; +} + +bool AliasSetTracker::add(VAArgInst *VAAI) { + bool NewPtr; + addPointer(VAAI->getOperand(0), ~0, AliasSet::ModRef, NewPtr); + return NewPtr; +} + + +bool AliasSetTracker::add(CallSite CS) { + if (isa<DbgInfoIntrinsic>(CS.getInstruction())) + return true; // Ignore DbgInfo Intrinsics. + if (AA.doesNotAccessMemory(CS)) + return true; // doesn't alias anything + + AliasSet *AS = findAliasSetForCallSite(CS); + if (!AS) { + AliasSets.push_back(new AliasSet()); + AS = &AliasSets.back(); + AS->addCallSite(CS, AA); + return true; + } else { + AS->addCallSite(CS, AA); + return false; + } +} + +bool AliasSetTracker::add(Instruction *I) { + // Dispatch to one of the other add methods... 
+ if (LoadInst *LI = dyn_cast<LoadInst>(I)) + return add(LI); + else if (StoreInst *SI = dyn_cast<StoreInst>(I)) + return add(SI); + else if (CallInst *CI = dyn_cast<CallInst>(I)) + return add(CI); + else if (InvokeInst *II = dyn_cast<InvokeInst>(I)) + return add(II); + else if (FreeInst *FI = dyn_cast<FreeInst>(I)) + return add(FI); + else if (VAArgInst *VAAI = dyn_cast<VAArgInst>(I)) + return add(VAAI); + return true; +} + +void AliasSetTracker::add(BasicBlock &BB) { + for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) + add(I); +} + +void AliasSetTracker::add(const AliasSetTracker &AST) { + assert(&AA == &AST.AA && + "Merging AliasSetTracker objects with different Alias Analyses!"); + + // Loop over all of the alias sets in AST, adding the pointers contained + // therein into the current alias sets. This can cause alias sets to be + // merged together in the current AST. + for (const_iterator I = AST.begin(), E = AST.end(); I != E; ++I) + if (!I->Forward) { // Ignore forwarding alias sets + AliasSet &AS = const_cast<AliasSet&>(*I); + + // If there are any call sites in the alias set, add them to this AST. + for (unsigned i = 0, e = AS.CallSites.size(); i != e; ++i) + add(AS.CallSites[i]); + + // Loop over all of the pointers in this alias set... + AliasSet::iterator I = AS.begin(), E = AS.end(); + bool X; + for (; I != E; ++I) { + AliasSet &NewAS = addPointer(I.getPointer(), I.getSize(), + (AliasSet::AccessType)AS.AccessTy, X); + if (AS.isVolatile()) NewAS.setVolatile(); + } + } +} + +/// remove - Remove the specified (potentially non-empty) alias set from the +/// tracker. +void AliasSetTracker::remove(AliasSet &AS) { + // Drop all call sites. + AS.CallSites.clear(); + + // Clear the alias set. + unsigned NumRefs = 0; + while (!AS.empty()) { + AliasSet::PointerRec *P = AS.PtrList; + + Value *ValToRemove = P->getValue(); + + // Unlink and delete entry from the list of values. + P->eraseFromList(); + + // Remember how many references need to be dropped. + ++NumRefs; + + // Finally, remove the entry. + PointerMap.erase(ValToRemove); + } + + // Stop using the alias set, removing it. 
+ AS.RefCount -= NumRefs; + if (AS.RefCount == 0) + AS.removeFromTracker(*this); +} + +bool AliasSetTracker::remove(Value *Ptr, unsigned Size) { + AliasSet *AS = findAliasSetForPointer(Ptr, Size); + if (!AS) return false; + remove(*AS); + return true; +} + +bool AliasSetTracker::remove(LoadInst *LI) { + unsigned Size = AA.getTargetData().getTypeStoreSize(LI->getType()); + AliasSet *AS = findAliasSetForPointer(LI->getOperand(0), Size); + if (!AS) return false; + remove(*AS); + return true; +} + +bool AliasSetTracker::remove(StoreInst *SI) { + unsigned Size = + AA.getTargetData().getTypeStoreSize(SI->getOperand(0)->getType()); + AliasSet *AS = findAliasSetForPointer(SI->getOperand(1), Size); + if (!AS) return false; + remove(*AS); + return true; +} + +bool AliasSetTracker::remove(FreeInst *FI) { + AliasSet *AS = findAliasSetForPointer(FI->getOperand(0), ~0); + if (!AS) return false; + remove(*AS); + return true; +} + +bool AliasSetTracker::remove(VAArgInst *VAAI) { + AliasSet *AS = findAliasSetForPointer(VAAI->getOperand(0), ~0); + if (!AS) return false; + remove(*AS); + return true; +} + +bool AliasSetTracker::remove(CallSite CS) { + if (AA.doesNotAccessMemory(CS)) + return false; // doesn't alias anything + + AliasSet *AS = findAliasSetForCallSite(CS); + if (!AS) return false; + remove(*AS); + return true; +} + +bool AliasSetTracker::remove(Instruction *I) { + // Dispatch to one of the other remove methods... + if (LoadInst *LI = dyn_cast<LoadInst>(I)) + return remove(LI); + else if (StoreInst *SI = dyn_cast<StoreInst>(I)) + return remove(SI); + else if (CallInst *CI = dyn_cast<CallInst>(I)) + return remove(CI); + else if (FreeInst *FI = dyn_cast<FreeInst>(I)) + return remove(FI); + else if (VAArgInst *VAAI = dyn_cast<VAArgInst>(I)) + return remove(VAAI); + return true; +} + + +// deleteValue method - This method is used to remove a pointer value from the +// AliasSetTracker entirely. It should be used when an instruction is deleted +// from the program to update the AST. If you don't use this, you would have +// dangling pointers to deleted instructions. +// +void AliasSetTracker::deleteValue(Value *PtrVal) { + // Notify the alias analysis implementation that this value is gone. + AA.deleteValue(PtrVal); + + // If this is a call instruction, remove the callsite from the appropriate + // AliasSet. + CallSite CS = CallSite::get(PtrVal); + if (CS.getInstruction()) + if (!AA.doesNotAccessMemory(CS)) + if (AliasSet *AS = findAliasSetForCallSite(CS)) + AS->removeCallSite(CS); + + // First, look up the PointerRec for this pointer. + DenseMap<Value*, AliasSet::PointerRec*>::iterator I = PointerMap.find(PtrVal); + if (I == PointerMap.end()) return; // Noop + + // If we found one, remove the pointer from the alias set it is in. + AliasSet::PointerRec *PtrValEnt = I->second; + AliasSet *AS = PtrValEnt->getAliasSet(*this); + + // Unlink and delete from the list of values. + PtrValEnt->eraseFromList(); + + // Stop using the alias set. + AS->dropRef(*this); + + PointerMap.erase(I); +} + +// copyValue - This method should be used whenever a preexisting value in the +// program is copied or cloned, introducing a new value. Note that it is ok for +// clients that use this method to introduce the same value multiple times: if +// the tracker already knows about a value, it will ignore the request. +// +void AliasSetTracker::copyValue(Value *From, Value *To) { + // Notify the alias analysis implementation that this value is copied. 
+ AA.copyValue(From, To); + + // First, look up the PointerRec for this pointer. + DenseMap<Value*, AliasSet::PointerRec*>::iterator I = PointerMap.find(From); + if (I == PointerMap.end()) + return; // Noop + assert(I->second->hasAliasSet() && "Dead entry?"); + + AliasSet::PointerRec &Entry = getEntryFor(To); + if (Entry.hasAliasSet()) return; // Already in the tracker! + + // Add it to the alias set it aliases... + I = PointerMap.find(From); + AliasSet *AS = I->second->getAliasSet(*this); + AS->addPointer(*this, Entry, I->second->getSize(), true); +} + + + +//===----------------------------------------------------------------------===// +// AliasSet/AliasSetTracker Printing Support +//===----------------------------------------------------------------------===// + +void AliasSet::print(std::ostream &OS) const { + OS << " AliasSet[" << (void*)this << "," << RefCount << "] "; + OS << (AliasTy == MustAlias ? "must" : "may") << " alias, "; + switch (AccessTy) { + case NoModRef: OS << "No access "; break; + case Refs : OS << "Ref "; break; + case Mods : OS << "Mod "; break; + case ModRef : OS << "Mod/Ref "; break; + default: assert(0 && "Bad value for AccessTy!"); + } + if (isVolatile()) OS << "[volatile] "; + if (Forward) + OS << " forwarding to " << (void*)Forward; + + + if (!empty()) { + OS << "Pointers: "; + for (iterator I = begin(), E = end(); I != E; ++I) { + if (I != begin()) OS << ", "; + WriteAsOperand(OS << "(", I.getPointer()); + OS << ", " << I.getSize() << ")"; + } + } + if (!CallSites.empty()) { + OS << "\n " << CallSites.size() << " Call Sites: "; + for (unsigned i = 0, e = CallSites.size(); i != e; ++i) { + if (i) OS << ", "; + WriteAsOperand(OS, CallSites[i].getCalledValue()); + } + } + OS << "\n"; +} + +void AliasSetTracker::print(std::ostream &OS) const { + OS << "Alias Set Tracker: " << AliasSets.size() << " alias sets for " + << PointerMap.size() << " pointer values.\n"; + for (const_iterator I = begin(), E = end(); I != E; ++I) + I->print(OS); + OS << "\n"; +} + +void AliasSet::dump() const { print (cerr); } +void AliasSetTracker::dump() const { print(cerr); } + +//===----------------------------------------------------------------------===// +// AliasSetPrinter Pass +//===----------------------------------------------------------------------===// + +namespace { + class VISIBILITY_HIDDEN AliasSetPrinter : public FunctionPass { + AliasSetTracker *Tracker; + public: + static char ID; // Pass identification, replacement for typeid + AliasSetPrinter() : FunctionPass(&ID) {} + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<AliasAnalysis>(); + } + + virtual bool runOnFunction(Function &F) { + Tracker = new AliasSetTracker(getAnalysis<AliasAnalysis>()); + + for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) + Tracker->add(&*I); + Tracker->print(cerr); + delete Tracker; + return false; + } + }; +} + +char AliasSetPrinter::ID = 0; +static RegisterPass<AliasSetPrinter> +X("print-alias-sets", "Alias Set Printer", false, true); diff --git a/lib/Analysis/Analysis.cpp b/lib/Analysis/Analysis.cpp new file mode 100644 index 0000000..493c6e8 --- /dev/null +++ b/lib/Analysis/Analysis.cpp @@ -0,0 +1,44 @@ +//===-- Analysis.cpp ------------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm-c/Analysis.h" +#include "llvm/Analysis/Verifier.h" +#include <fstream> +#include <cstring> + +using namespace llvm; + +int LLVMVerifyModule(LLVMModuleRef M, LLVMVerifierFailureAction Action, + char **OutMessages) { + std::string Messages; + + int Result = verifyModule(*unwrap(M), + static_cast<VerifierFailureAction>(Action), + OutMessages? &Messages : 0); + + if (OutMessages) + *OutMessages = strdup(Messages.c_str()); + + return Result; +} + +int LLVMVerifyFunction(LLVMValueRef Fn, LLVMVerifierFailureAction Action) { + return verifyFunction(*unwrap<Function>(Fn), + static_cast<VerifierFailureAction>(Action)); +} + +void LLVMViewFunctionCFG(LLVMValueRef Fn) { + Function *F = unwrap<Function>(Fn); + F->viewCFG(); +} + +void LLVMViewFunctionCFGOnly(LLVMValueRef Fn) { + Function *F = unwrap<Function>(Fn); + F->viewCFGOnly(); +} diff --git a/lib/Analysis/BasicAliasAnalysis.cpp b/lib/Analysis/BasicAliasAnalysis.cpp new file mode 100644 index 0000000..d062045 --- /dev/null +++ b/lib/Analysis/BasicAliasAnalysis.cpp @@ -0,0 +1,838 @@ +//===- BasicAliasAnalysis.cpp - Local Alias Analysis Impl -----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the default implementation of the Alias Analysis interface +// that simply implements a few identities (two different globals cannot alias, +// etc), but otherwise does no analysis. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/GlobalVariable.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Pass.h" +#include "llvm/Target/TargetData.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Support/ManagedStatic.h" +#include <algorithm> +using namespace llvm; + +//===----------------------------------------------------------------------===// +// Useful predicates +//===----------------------------------------------------------------------===// + +static const User *isGEP(const Value *V) { + if (isa<GetElementPtrInst>(V) || + (isa<ConstantExpr>(V) && + cast<ConstantExpr>(V)->getOpcode() == Instruction::GetElementPtr)) + return cast<User>(V); + return 0; +} + +static const Value *GetGEPOperands(const Value *V, + SmallVector<Value*, 16> &GEPOps) { + assert(GEPOps.empty() && "Expect empty list to populate!"); + GEPOps.insert(GEPOps.end(), cast<User>(V)->op_begin()+1, + cast<User>(V)->op_end()); + + // Accumulate all of the chained indexes into the operand array + V = cast<User>(V)->getOperand(0); + + while (const User *G = isGEP(V)) { + if (!isa<Constant>(GEPOps[0]) || isa<GlobalValue>(GEPOps[0]) || + !cast<Constant>(GEPOps[0])->isNullValue()) + break; // Don't handle folding arbitrary pointer offsets yet... 
+ GEPOps.erase(GEPOps.begin()); // Drop the zero index + GEPOps.insert(GEPOps.begin(), G->op_begin()+1, G->op_end()); + V = G->getOperand(0); + } + return V; +} + +/// isKnownNonNull - Return true if we know that the specified value is never +/// null. +static bool isKnownNonNull(const Value *V) { + // Alloca never returns null, malloc might. + if (isa<AllocaInst>(V)) return true; + + // A byval argument is never null. + if (const Argument *A = dyn_cast<Argument>(V)) + return A->hasByValAttr(); + + // Global values are not null unless extern weak. + if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) + return !GV->hasExternalWeakLinkage(); + return false; +} + +/// isNonEscapingLocalObject - Return true if the pointer is to a function-local +/// object that never escapes from the function. +static bool isNonEscapingLocalObject(const Value *V) { + // If this is a local allocation, check to see if it escapes. + if (isa<AllocationInst>(V) || isNoAliasCall(V)) + return !PointerMayBeCaptured(V, false); + + // If this is an argument that corresponds to a byval or noalias argument, + // then it has not escaped before entering the function. Check if it escapes + // inside the function. + if (const Argument *A = dyn_cast<Argument>(V)) + if (A->hasByValAttr() || A->hasNoAliasAttr()) { + // Don't bother analyzing arguments already known not to escape. + if (A->hasNoCaptureAttr()) + return true; + return !PointerMayBeCaptured(V, false); + } + return false; +} + + +/// isObjectSmallerThan - Return true if we can prove that the object specified +/// by V is smaller than Size. +static bool isObjectSmallerThan(const Value *V, unsigned Size, + const TargetData &TD) { + const Type *AccessTy; + if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) { + AccessTy = GV->getType()->getElementType(); + } else if (const AllocationInst *AI = dyn_cast<AllocationInst>(V)) { + if (!AI->isArrayAllocation()) + AccessTy = AI->getType()->getElementType(); + else + return false; + } else if (const Argument *A = dyn_cast<Argument>(V)) { + if (A->hasByValAttr()) + AccessTy = cast<PointerType>(A->getType())->getElementType(); + else + return false; + } else { + return false; + } + + if (AccessTy->isSized()) + return TD.getTypeAllocSize(AccessTy) < Size; + return false; +} + +//===----------------------------------------------------------------------===// +// NoAA Pass +//===----------------------------------------------------------------------===// + +namespace { + /// NoAA - This class implements the -no-aa pass, which always returns "I + /// don't know" for alias queries. NoAA is unlike other alias analysis + /// implementations, in that it does not chain to a previous analysis. As + /// such it doesn't follow many of the rules that other alias analyses must. 
+ /// + struct VISIBILITY_HIDDEN NoAA : public ImmutablePass, public AliasAnalysis { + static char ID; // Class identification, replacement for typeinfo + NoAA() : ImmutablePass(&ID) {} + explicit NoAA(void *PID) : ImmutablePass(PID) { } + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<TargetData>(); + } + + virtual void initializePass() { + TD = &getAnalysis<TargetData>(); + } + + virtual AliasResult alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size) { + return MayAlias; + } + + virtual void getArgumentAccesses(Function *F, CallSite CS, + std::vector<PointerAccessInfo> &Info) { + assert(0 && "This method may not be called on this function!"); + } + + virtual void getMustAliases(Value *P, std::vector<Value*> &RetVals) { } + virtual bool pointsToConstantMemory(const Value *P) { return false; } + virtual ModRefResult getModRefInfo(CallSite CS, Value *P, unsigned Size) { + return ModRef; + } + virtual ModRefResult getModRefInfo(CallSite CS1, CallSite CS2) { + return ModRef; + } + virtual bool hasNoModRefInfoForCalls() const { return true; } + + virtual void deleteValue(Value *V) {} + virtual void copyValue(Value *From, Value *To) {} + }; +} // End of anonymous namespace + +// Register this pass... +char NoAA::ID = 0; +static RegisterPass<NoAA> +U("no-aa", "No Alias Analysis (always returns 'may' alias)", true, true); + +// Declare that we implement the AliasAnalysis interface +static RegisterAnalysisGroup<AliasAnalysis> V(U); + +ImmutablePass *llvm::createNoAAPass() { return new NoAA(); } + +//===----------------------------------------------------------------------===// +// BasicAA Pass +//===----------------------------------------------------------------------===// + +namespace { + /// BasicAliasAnalysis - This is the default alias analysis implementation. + /// Because it doesn't chain to a previous alias analysis (like -no-aa), it + /// derives from the NoAA class. + struct VISIBILITY_HIDDEN BasicAliasAnalysis : public NoAA { + static char ID; // Class identification, replacement for typeinfo + BasicAliasAnalysis() : NoAA(&ID) {} + AliasResult alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size); + + ModRefResult getModRefInfo(CallSite CS, Value *P, unsigned Size); + ModRefResult getModRefInfo(CallSite CS1, CallSite CS2); + + /// hasNoModRefInfoForCalls - We can provide mod/ref information against + /// non-escaping allocations. + virtual bool hasNoModRefInfoForCalls() const { return false; } + + /// pointsToConstantMemory - Chase pointers until we find a (constant + /// global) or not. + bool pointsToConstantMemory(const Value *P); + + private: + // CheckGEPInstructions - Check two GEP instructions with known + // must-aliasing base pointers. This checks to see if the index expressions + // preclude the pointers from aliasing... + AliasResult + CheckGEPInstructions(const Type* BasePtr1Ty, + Value **GEP1Ops, unsigned NumGEP1Ops, unsigned G1Size, + const Type *BasePtr2Ty, + Value **GEP2Ops, unsigned NumGEP2Ops, unsigned G2Size); + }; +} // End of anonymous namespace + +// Register this pass... 
+char BasicAliasAnalysis::ID = 0; +static RegisterPass<BasicAliasAnalysis> +X("basicaa", "Basic Alias Analysis (default AA impl)", false, true); + +// Declare that we implement the AliasAnalysis interface +static RegisterAnalysisGroup<AliasAnalysis, true> Y(X); + +ImmutablePass *llvm::createBasicAliasAnalysisPass() { + return new BasicAliasAnalysis(); +} + + +/// pointsToConstantMemory - Chase pointers until we find a (constant +/// global) or not. +bool BasicAliasAnalysis::pointsToConstantMemory(const Value *P) { + if (const GlobalVariable *GV = + dyn_cast<GlobalVariable>(P->getUnderlyingObject())) + return GV->isConstant(); + return false; +} + + +// getModRefInfo - Check to see if the specified callsite can clobber the +// specified memory object. Since we only look at local properties of this +// function, we really can't say much about this query. We do, however, use +// simple "address taken" analysis on local objects. +// +AliasAnalysis::ModRefResult +BasicAliasAnalysis::getModRefInfo(CallSite CS, Value *P, unsigned Size) { + if (!isa<Constant>(P)) { + const Value *Object = P->getUnderlyingObject(); + + // If this is a tail call and P points to a stack location, we know that + // the tail call cannot access or modify the local stack. + // We cannot exclude byval arguments here; these belong to the caller of + // the current function not to the current function, and a tail callee + // may reference them. + if (isa<AllocaInst>(Object)) + if (CallInst *CI = dyn_cast<CallInst>(CS.getInstruction())) + if (CI->isTailCall()) + return NoModRef; + + // If the pointer is to a locally allocated object that does not escape, + // then the call can not mod/ref the pointer unless the call takes the + // argument without capturing it. + if (isNonEscapingLocalObject(Object) && CS.getInstruction() != Object) { + bool passedAsArg = false; + // TODO: Eventually only check 'nocapture' arguments. + for (CallSite::arg_iterator CI = CS.arg_begin(), CE = CS.arg_end(); + CI != CE; ++CI) + if (isa<PointerType>((*CI)->getType()) && + alias(cast<Value>(CI), ~0U, P, ~0U) != NoAlias) + passedAsArg = true; + + if (!passedAsArg) + return NoModRef; + } + } + + // The AliasAnalysis base class has some smarts, lets use them. + return AliasAnalysis::getModRefInfo(CS, P, Size); +} + + +AliasAnalysis::ModRefResult +BasicAliasAnalysis::getModRefInfo(CallSite CS1, CallSite CS2) { + // If CS1 or CS2 are readnone, they don't interact. + ModRefBehavior CS1B = AliasAnalysis::getModRefBehavior(CS1); + if (CS1B == DoesNotAccessMemory) return NoModRef; + + ModRefBehavior CS2B = AliasAnalysis::getModRefBehavior(CS2); + if (CS2B == DoesNotAccessMemory) return NoModRef; + + // If they both only read from memory, just return ref. + if (CS1B == OnlyReadsMemory && CS2B == OnlyReadsMemory) + return Ref; + + // Otherwise, fall back to NoAA (mod+ref). + return NoAA::getModRefInfo(CS1, CS2); +} + + +// alias - Provide a bunch of ad-hoc rules to disambiguate in common cases, such +// as array references. 
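+// For example, two distinct identified objects (allocas, globals, noalias
+// calls) cannot alias, and getelementptrs off the same base with provably
+// different constant offsets (e.g. &A[1] vs. &A[2]) cannot alias either.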
+// +AliasAnalysis::AliasResult +BasicAliasAnalysis::alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size) { + // Strip off any constant expression casts if they exist + if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V1)) + if (CE->isCast() && isa<PointerType>(CE->getOperand(0)->getType())) + V1 = CE->getOperand(0); + if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V2)) + if (CE->isCast() && isa<PointerType>(CE->getOperand(0)->getType())) + V2 = CE->getOperand(0); + + // Are we checking for alias of the same value? + if (V1 == V2) return MustAlias; + + if (!isa<PointerType>(V1->getType()) || !isa<PointerType>(V2->getType())) + return NoAlias; // Scalars cannot alias each other + + // Strip off cast instructions. Since V1 and V2 are pointers, they must be + // pointer<->pointer bitcasts. + if (const BitCastInst *I = dyn_cast<BitCastInst>(V1)) + return alias(I->getOperand(0), V1Size, V2, V2Size); + if (const BitCastInst *I = dyn_cast<BitCastInst>(V2)) + return alias(V1, V1Size, I->getOperand(0), V2Size); + + // Figure out what objects these things are pointing to if we can. + const Value *O1 = V1->getUnderlyingObject(); + const Value *O2 = V2->getUnderlyingObject(); + + if (O1 != O2) { + // If V1/V2 point to two different objects we know that we have no alias. + if (isIdentifiedObject(O1) && isIdentifiedObject(O2)) + return NoAlias; + + // Arguments can't alias with local allocations or noalias calls. + if ((isa<Argument>(O1) && (isa<AllocationInst>(O2) || isNoAliasCall(O2))) || + (isa<Argument>(O2) && (isa<AllocationInst>(O1) || isNoAliasCall(O1)))) + return NoAlias; + + // Most objects can't alias null. + if ((isa<ConstantPointerNull>(V2) && isKnownNonNull(O1)) || + (isa<ConstantPointerNull>(V1) && isKnownNonNull(O2))) + return NoAlias; + } + + // If the size of one access is larger than the entire object on the other + // side, then we know such behavior is undefined and can assume no alias. + const TargetData &TD = getTargetData(); + if ((V1Size != ~0U && isObjectSmallerThan(O2, V1Size, TD)) || + (V2Size != ~0U && isObjectSmallerThan(O1, V2Size, TD))) + return NoAlias; + + // If one pointer is the result of a call/invoke and the other is a + // non-escaping local object, then we know the object couldn't escape to a + // point where the call could return it. + if ((isa<CallInst>(O1) || isa<InvokeInst>(O1)) && + isNonEscapingLocalObject(O2) && O1 != O2) + return NoAlias; + if ((isa<CallInst>(O2) || isa<InvokeInst>(O2)) && + isNonEscapingLocalObject(O1) && O1 != O2) + return NoAlias; + + // If we have two gep instructions with must-alias'ing base pointers, figure + // out if the indexes to the GEP tell us anything about the derived pointer. + // Note that we also handle chains of getelementptr instructions as well as + // constant expression getelementptrs here. + // + if (isGEP(V1) && isGEP(V2)) { + const User *GEP1 = cast<User>(V1); + const User *GEP2 = cast<User>(V2); + + // If V1 and V2 are identical GEPs, just recurse down on both of them. + // This allows us to analyze things like: + // P = gep A, 0, i, 1 + // Q = gep B, 0, i, 1 + // by just analyzing A and B. This is even safe for variable indices. + if (GEP1->getType() == GEP2->getType() && + GEP1->getNumOperands() == GEP2->getNumOperands() && + GEP1->getOperand(0)->getType() == GEP2->getOperand(0)->getType() && + // All operands are the same, ignoring the base. 
+ std::equal(GEP1->op_begin()+1, GEP1->op_end(), GEP2->op_begin()+1)) + return alias(GEP1->getOperand(0), V1Size, GEP2->getOperand(0), V2Size); + + + // Drill down into the first non-gep value, to test for must-aliasing of + // the base pointers. + while (isGEP(GEP1->getOperand(0)) && + GEP1->getOperand(1) == + Constant::getNullValue(GEP1->getOperand(1)->getType())) + GEP1 = cast<User>(GEP1->getOperand(0)); + const Value *BasePtr1 = GEP1->getOperand(0); + + while (isGEP(GEP2->getOperand(0)) && + GEP2->getOperand(1) == + Constant::getNullValue(GEP2->getOperand(1)->getType())) + GEP2 = cast<User>(GEP2->getOperand(0)); + const Value *BasePtr2 = GEP2->getOperand(0); + + // Do the base pointers alias? + AliasResult BaseAlias = alias(BasePtr1, ~0U, BasePtr2, ~0U); + if (BaseAlias == NoAlias) return NoAlias; + if (BaseAlias == MustAlias) { + // If the base pointers alias each other exactly, check to see if we can + // figure out anything about the resultant pointers, to try to prove + // non-aliasing. + + // Collect all of the chained GEP operands together into one simple place + SmallVector<Value*, 16> GEP1Ops, GEP2Ops; + BasePtr1 = GetGEPOperands(V1, GEP1Ops); + BasePtr2 = GetGEPOperands(V2, GEP2Ops); + + // If GetGEPOperands were able to fold to the same must-aliased pointer, + // do the comparison. + if (BasePtr1 == BasePtr2) { + AliasResult GAlias = + CheckGEPInstructions(BasePtr1->getType(), + &GEP1Ops[0], GEP1Ops.size(), V1Size, + BasePtr2->getType(), + &GEP2Ops[0], GEP2Ops.size(), V2Size); + if (GAlias != MayAlias) + return GAlias; + } + } + } + + // Check to see if these two pointers are related by a getelementptr + // instruction. If one pointer is a GEP with a non-zero index of the other + // pointer, we know they cannot alias. + // + if (isGEP(V2)) { + std::swap(V1, V2); + std::swap(V1Size, V2Size); + } + + if (V1Size != ~0U && V2Size != ~0U) + if (isGEP(V1)) { + SmallVector<Value*, 16> GEPOperands; + const Value *BasePtr = GetGEPOperands(V1, GEPOperands); + + AliasResult R = alias(BasePtr, V1Size, V2, V2Size); + if (R == MustAlias) { + // If there is at least one non-zero constant index, we know they cannot + // alias. + bool ConstantFound = false; + bool AllZerosFound = true; + for (unsigned i = 0, e = GEPOperands.size(); i != e; ++i) + if (const Constant *C = dyn_cast<Constant>(GEPOperands[i])) { + if (!C->isNullValue()) { + ConstantFound = true; + AllZerosFound = false; + break; + } + } else { + AllZerosFound = false; + } + + // If we have getelementptr <ptr>, 0, 0, 0, 0, ... and V2 must aliases + // the ptr, the end result is a must alias also. + if (AllZerosFound) + return MustAlias; + + if (ConstantFound) { + if (V2Size <= 1 && V1Size <= 1) // Just pointer check? + return NoAlias; + + // Otherwise we have to check to see that the distance is more than + // the size of the argument... build an index vector that is equal to + // the arguments provided, except substitute 0's for any variable + // indexes we find... 
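+ // For example, with i32 elements, gep %A, 0, 4 starts 16 bytes past %A,
+ // so a 4-byte access through each of the two pointers cannot overlap.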
+ if (cast<PointerType>( + BasePtr->getType())->getElementType()->isSized()) { + for (unsigned i = 0; i != GEPOperands.size(); ++i) + if (!isa<ConstantInt>(GEPOperands[i])) + GEPOperands[i] = + Constant::getNullValue(GEPOperands[i]->getType()); + int64_t Offset = + getTargetData().getIndexedOffset(BasePtr->getType(), + &GEPOperands[0], + GEPOperands.size()); + + if (Offset >= (int64_t)V2Size || Offset <= -(int64_t)V1Size) + return NoAlias; + } + } + } + } + + return MayAlias; +} + +// This function is used to determine if the indices of two GEP instructions are +// equal. V1 and V2 are the indices. +static bool IndexOperandsEqual(Value *V1, Value *V2) { + if (V1->getType() == V2->getType()) + return V1 == V2; + if (Constant *C1 = dyn_cast<Constant>(V1)) + if (Constant *C2 = dyn_cast<Constant>(V2)) { + // Sign extend the constants to long types, if necessary + if (C1->getType() != Type::Int64Ty) + C1 = ConstantExpr::getSExt(C1, Type::Int64Ty); + if (C2->getType() != Type::Int64Ty) + C2 = ConstantExpr::getSExt(C2, Type::Int64Ty); + return C1 == C2; + } + return false; +} + +/// CheckGEPInstructions - Check two GEP instructions with known must-aliasing +/// base pointers. This checks to see if the index expressions preclude the +/// pointers from aliasing... +AliasAnalysis::AliasResult +BasicAliasAnalysis::CheckGEPInstructions( + const Type* BasePtr1Ty, Value **GEP1Ops, unsigned NumGEP1Ops, unsigned G1S, + const Type *BasePtr2Ty, Value **GEP2Ops, unsigned NumGEP2Ops, unsigned G2S) { + // We currently can't handle the case when the base pointers have different + // primitive types. Since this is uncommon anyway, we are happy being + // extremely conservative. + if (BasePtr1Ty != BasePtr2Ty) + return MayAlias; + + const PointerType *GEPPointerTy = cast<PointerType>(BasePtr1Ty); + + // Find the (possibly empty) initial sequence of equal values... which are not + // necessarily constants. + unsigned NumGEP1Operands = NumGEP1Ops, NumGEP2Operands = NumGEP2Ops; + unsigned MinOperands = std::min(NumGEP1Operands, NumGEP2Operands); + unsigned MaxOperands = std::max(NumGEP1Operands, NumGEP2Operands); + unsigned UnequalOper = 0; + while (UnequalOper != MinOperands && + IndexOperandsEqual(GEP1Ops[UnequalOper], GEP2Ops[UnequalOper])) { + // Advance through the type as we go... + ++UnequalOper; + if (const CompositeType *CT = dyn_cast<CompositeType>(BasePtr1Ty)) + BasePtr1Ty = CT->getTypeAtIndex(GEP1Ops[UnequalOper-1]); + else { + // If all operands equal each other, then the derived pointers must + // alias each other... + BasePtr1Ty = 0; + assert(UnequalOper == NumGEP1Operands && UnequalOper == NumGEP2Operands && + "Ran out of type nesting, but not out of operands?"); + return MustAlias; + } + } + + // If we have seen all constant operands, and run out of indexes on one of the + // getelementptrs, check to see if the tail of the leftover one is all zeros. + // If so, return mustalias. + if (UnequalOper == MinOperands) { + if (NumGEP1Ops < NumGEP2Ops) { + std::swap(GEP1Ops, GEP2Ops); + std::swap(NumGEP1Ops, NumGEP2Ops); + } + + bool AllAreZeros = true; + for (unsigned i = UnequalOper; i != MaxOperands; ++i) + if (!isa<Constant>(GEP1Ops[i]) || + !cast<Constant>(GEP1Ops[i])->isNullValue()) { + AllAreZeros = false; + break; + } + if (AllAreZeros) return MustAlias; + } + + + // So now we know that the indexes derived from the base pointers, + // which are known to alias, are different. We can still determine a + // no-alias result if there are differing constant pairs in the index + // chain. 
For example: + // A[i][0] != A[j][1] iff (&A[0][1]-&A[0][0] >= std::max(G1S, G2S)) + // + // We have to be careful here about array accesses. In particular, consider: + // A[1][0] vs A[0][i] + // In this case, we don't *know* that the array will be accessed in bounds: + // the index could even be negative. Because of this, we have to + // conservatively *give up* and return may alias. We disregard differing + // array subscripts that are followed by a variable index without going + // through a struct. + // + unsigned SizeMax = std::max(G1S, G2S); + if (SizeMax == ~0U) return MayAlias; // Avoid frivolous work. + + // Scan for the first operand that is constant and unequal in the + // two getelementptrs... + unsigned FirstConstantOper = UnequalOper; + for (; FirstConstantOper != MinOperands; ++FirstConstantOper) { + const Value *G1Oper = GEP1Ops[FirstConstantOper]; + const Value *G2Oper = GEP2Ops[FirstConstantOper]; + + if (G1Oper != G2Oper) // Found non-equal constant indexes... + if (Constant *G1OC = dyn_cast<ConstantInt>(const_cast<Value*>(G1Oper))) + if (Constant *G2OC = dyn_cast<ConstantInt>(const_cast<Value*>(G2Oper))){ + if (G1OC->getType() != G2OC->getType()) { + // Sign extend both operands to long. + if (G1OC->getType() != Type::Int64Ty) + G1OC = ConstantExpr::getSExt(G1OC, Type::Int64Ty); + if (G2OC->getType() != Type::Int64Ty) + G2OC = ConstantExpr::getSExt(G2OC, Type::Int64Ty); + GEP1Ops[FirstConstantOper] = G1OC; + GEP2Ops[FirstConstantOper] = G2OC; + } + + if (G1OC != G2OC) { + // Handle the "be careful" case above: if this is an array/vector + // subscript, scan for a subsequent variable array index. + if (const SequentialType *STy = + dyn_cast<SequentialType>(BasePtr1Ty)) { + const Type *NextTy = STy; + bool isBadCase = false; + + for (unsigned Idx = FirstConstantOper; + Idx != MinOperands && isa<SequentialType>(NextTy); ++Idx) { + const Value *V1 = GEP1Ops[Idx], *V2 = GEP2Ops[Idx]; + if (!isa<Constant>(V1) || !isa<Constant>(V2)) { + isBadCase = true; + break; + } + // If the array is indexed beyond the bounds of the static type + // at this level, it will also fall into the "be careful" case. + // It would theoretically be possible to analyze these cases, + // but for now just be conservatively correct. + if (const ArrayType *ATy = dyn_cast<ArrayType>(STy)) + if (cast<ConstantInt>(G1OC)->getZExtValue() >= + ATy->getNumElements() || + cast<ConstantInt>(G2OC)->getZExtValue() >= + ATy->getNumElements()) { + isBadCase = true; + break; + } + if (const VectorType *VTy = dyn_cast<VectorType>(STy)) + if (cast<ConstantInt>(G1OC)->getZExtValue() >= + VTy->getNumElements() || + cast<ConstantInt>(G2OC)->getZExtValue() >= + VTy->getNumElements()) { + isBadCase = true; + break; + } + STy = cast<SequentialType>(NextTy); + NextTy = cast<SequentialType>(NextTy)->getElementType(); + } + + if (isBadCase) G1OC = 0; + } + + // Make sure they are comparable (ie, not constant expressions), and + // make sure the GEP with the smaller leading constant is GEP1. + if (G1OC) { + Constant *Compare = ConstantExpr::getICmp(ICmpInst::ICMP_SGT, + G1OC, G2OC); + if (ConstantInt *CV = dyn_cast<ConstantInt>(Compare)) { + if (CV->getZExtValue()) { // If they are comparable and G2 > G1 + std::swap(GEP1Ops, GEP2Ops); // Make GEP1 < GEP2 + std::swap(NumGEP1Ops, NumGEP2Ops); + } + break; + } + } + } + } + BasePtr1Ty = cast<CompositeType>(BasePtr1Ty)->getTypeAtIndex(G1Oper); + } + + // No shared constant operands, and we ran out of common operands. 
At this + // point, the GEP instructions have run through all of their operands, and we + // haven't found evidence that there are any deltas between the GEP's. + // However, one GEP may have more operands than the other. If this is the + // case, there may still be hope. Check this now. + if (FirstConstantOper == MinOperands) { + // Make GEP1Ops be the longer one if there is a longer one. + if (NumGEP1Ops < NumGEP2Ops) { + std::swap(GEP1Ops, GEP2Ops); + std::swap(NumGEP1Ops, NumGEP2Ops); + } + + // Is there anything to check? + if (NumGEP1Ops > MinOperands) { + for (unsigned i = FirstConstantOper; i != MaxOperands; ++i) + if (isa<ConstantInt>(GEP1Ops[i]) && + !cast<ConstantInt>(GEP1Ops[i])->isZero()) { + // Yup, there's a constant in the tail. Set all variables to + // constants in the GEP instruction to make it suitable for + // TargetData::getIndexedOffset. + for (i = 0; i != MaxOperands; ++i) + if (!isa<ConstantInt>(GEP1Ops[i])) + GEP1Ops[i] = Constant::getNullValue(GEP1Ops[i]->getType()); + // Okay, now get the offset. This is the relative offset for the full + // instruction. + const TargetData &TD = getTargetData(); + int64_t Offset1 = TD.getIndexedOffset(GEPPointerTy, GEP1Ops, + NumGEP1Ops); + + // Now check without any constants at the end. + int64_t Offset2 = TD.getIndexedOffset(GEPPointerTy, GEP1Ops, + MinOperands); + + // Make sure we compare the absolute difference. + if (Offset1 > Offset2) + std::swap(Offset1, Offset2); + + // If the tail provided a bit enough offset, return noalias! + if ((uint64_t)(Offset2-Offset1) >= SizeMax) + return NoAlias; + // Otherwise break - we don't look for another constant in the tail. + break; + } + } + + // Couldn't find anything useful. + return MayAlias; + } + + // If there are non-equal constants arguments, then we can figure + // out a minimum known delta between the two index expressions... at + // this point we know that the first constant index of GEP1 is less + // than the first constant index of GEP2. + + // Advance BasePtr[12]Ty over this first differing constant operand. + BasePtr2Ty = cast<CompositeType>(BasePtr1Ty)-> + getTypeAtIndex(GEP2Ops[FirstConstantOper]); + BasePtr1Ty = cast<CompositeType>(BasePtr1Ty)-> + getTypeAtIndex(GEP1Ops[FirstConstantOper]); + + // We are going to be using TargetData::getIndexedOffset to determine the + // offset that each of the GEP's is reaching. To do this, we have to convert + // all variable references to constant references. To do this, we convert the + // initial sequence of array subscripts into constant zeros to start with. + const Type *ZeroIdxTy = GEPPointerTy; + for (unsigned i = 0; i != FirstConstantOper; ++i) { + if (!isa<StructType>(ZeroIdxTy)) + GEP1Ops[i] = GEP2Ops[i] = Constant::getNullValue(Type::Int32Ty); + + if (const CompositeType *CT = dyn_cast<CompositeType>(ZeroIdxTy)) + ZeroIdxTy = CT->getTypeAtIndex(GEP1Ops[i]); + } + + // We know that GEP1Ops[FirstConstantOper] & GEP2Ops[FirstConstantOper] are ok + + // Loop over the rest of the operands... + for (unsigned i = FirstConstantOper+1; i != MaxOperands; ++i) { + const Value *Op1 = i < NumGEP1Ops ? GEP1Ops[i] : 0; + const Value *Op2 = i < NumGEP2Ops ? GEP2Ops[i] : 0; + // If they are equal, use a zero index... + if (Op1 == Op2 && BasePtr1Ty == BasePtr2Ty) { + if (!isa<ConstantInt>(Op1)) + GEP1Ops[i] = GEP2Ops[i] = Constant::getNullValue(Op1->getType()); + // Otherwise, just keep the constants we have. 
+ } else { + if (Op1) { + if (const ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { + // If this is an array index, make sure the array element is in range. + if (const ArrayType *AT = dyn_cast<ArrayType>(BasePtr1Ty)) { + if (Op1C->getZExtValue() >= AT->getNumElements()) + return MayAlias; // Be conservative with out-of-range accesses + } else if (const VectorType *VT = dyn_cast<VectorType>(BasePtr1Ty)) { + if (Op1C->getZExtValue() >= VT->getNumElements()) + return MayAlias; // Be conservative with out-of-range accesses + } + + } else { + // GEP1 is known to produce a value less than GEP2. To be + // conservatively correct, we must assume the largest possible + // constant is used in this position. This cannot be the initial + // index to the GEP instructions (because we know we have at least one + // element before this one with the different constant arguments), so + // we know that the current index must be into either a struct or + // array. Because we know it's not constant, this cannot be a + // structure index. Because of this, we can calculate the maximum + // value possible. + // + if (const ArrayType *AT = dyn_cast<ArrayType>(BasePtr1Ty)) + GEP1Ops[i] = ConstantInt::get(Type::Int64Ty,AT->getNumElements()-1); + else if (const VectorType *VT = dyn_cast<VectorType>(BasePtr1Ty)) + GEP1Ops[i] = ConstantInt::get(Type::Int64Ty,VT->getNumElements()-1); + } + } + + if (Op2) { + if (const ConstantInt *Op2C = dyn_cast<ConstantInt>(Op2)) { + // If this is an array index, make sure the array element is in range. + if (const ArrayType *AT = dyn_cast<ArrayType>(BasePtr2Ty)) { + if (Op2C->getZExtValue() >= AT->getNumElements()) + return MayAlias; // Be conservative with out-of-range accesses + } else if (const VectorType *VT = dyn_cast<VectorType>(BasePtr2Ty)) { + if (Op2C->getZExtValue() >= VT->getNumElements()) + return MayAlias; // Be conservative with out-of-range accesses + } + } else { // Conservatively assume the minimum value for this index + GEP2Ops[i] = Constant::getNullValue(Op2->getType()); + } + } + } + + if (BasePtr1Ty && Op1) { + if (const CompositeType *CT = dyn_cast<CompositeType>(BasePtr1Ty)) + BasePtr1Ty = CT->getTypeAtIndex(GEP1Ops[i]); + else + BasePtr1Ty = 0; + } + + if (BasePtr2Ty && Op2) { + if (const CompositeType *CT = dyn_cast<CompositeType>(BasePtr2Ty)) + BasePtr2Ty = CT->getTypeAtIndex(GEP2Ops[i]); + else + BasePtr2Ty = 0; + } + } + + if (GEPPointerTy->getElementType()->isSized()) { + int64_t Offset1 = + getTargetData().getIndexedOffset(GEPPointerTy, GEP1Ops, NumGEP1Ops); + int64_t Offset2 = + getTargetData().getIndexedOffset(GEPPointerTy, GEP2Ops, NumGEP2Ops); + assert(Offset1 != Offset2 && + "There is at least one different constant here!"); + + // Make sure we compare the absolute difference. + if (Offset1 > Offset2) + std::swap(Offset1, Offset2); + + if ((uint64_t)(Offset2-Offset1) >= SizeMax) { + //cerr << "Determined that these two GEP's don't alias [" + // << SizeMax << " bytes]: \n" << *GEP1 << *GEP2; + return NoAlias; + } + } + return MayAlias; +} + +// Make sure that anything that uses AliasAnalysis pulls in this file... +DEFINING_FILE_FOR(BasicAliasAnalysis) diff --git a/lib/Analysis/CFGPrinter.cpp b/lib/Analysis/CFGPrinter.cpp new file mode 100644 index 0000000..143220c --- /dev/null +++ b/lib/Analysis/CFGPrinter.cpp @@ -0,0 +1,221 @@ +//===- CFGPrinter.cpp - DOT printer for the control flow graph ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. 
See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines a '-dot-cfg' analysis pass, which emits the +// cfg.<fnname>.dot file for each function in the program, with a graph of the +// CFG for that function. +// +// The other main feature of this file is that it implements the +// Function::viewCFG method, which is useful for debugging passes which operate +// on the CFG. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/CFGPrinter.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Config/config.h" +#include <iosfwd> +#include <sstream> +#include <fstream> +using namespace llvm; + +/// CFGOnly flag - This is used to control whether or not the CFG graph printer +/// prints out the contents of basic blocks or not. This is acceptable because +/// this code is only really used for debugging purposes. +/// +static bool CFGOnly = false; + +namespace llvm { +template<> +struct DOTGraphTraits<const Function*> : public DefaultDOTGraphTraits { + static std::string getGraphName(const Function *F) { + return "CFG for '" + F->getName() + "' function"; + } + + static std::string getNodeLabel(const BasicBlock *Node, + const Function *Graph) { + if (CFGOnly && !Node->getName().empty()) + return Node->getName() + ":"; + + std::ostringstream Out; + if (CFGOnly) { + WriteAsOperand(Out, Node, false); + return Out.str(); + } + + if (Node->getName().empty()) { + WriteAsOperand(Out, Node, false); + Out << ":"; + } + + Out << *Node; + std::string OutStr = Out.str(); + if (OutStr[0] == '\n') OutStr.erase(OutStr.begin()); + + // Process string output to make it nicer... + for (unsigned i = 0; i != OutStr.length(); ++i) + if (OutStr[i] == '\n') { // Left justify + OutStr[i] = '\\'; + OutStr.insert(OutStr.begin()+i+1, 'l'); + } else if (OutStr[i] == ';') { // Delete comments! + unsigned Idx = OutStr.find('\n', i+1); // Find end of line + OutStr.erase(OutStr.begin()+i, OutStr.begin()+Idx); + --i; + } + + return OutStr; + } + + static std::string getEdgeSourceLabel(const BasicBlock *Node, + succ_const_iterator I) { + // Label source of conditional branches with "T" or "F" + if (const BranchInst *BI = dyn_cast<BranchInst>(Node->getTerminator())) + if (BI->isConditional()) + return (I == succ_begin(Node)) ? 
"T" : "F"; + return ""; + } +}; +} + +namespace { + struct VISIBILITY_HIDDEN CFGViewer : public FunctionPass { + static char ID; // Pass identifcation, replacement for typeid + CFGViewer() : FunctionPass(&ID) {} + + virtual bool runOnFunction(Function &F) { + F.viewCFG(); + return false; + } + + void print(std::ostream &OS, const Module* = 0) const {} + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + }; +} + +char CFGViewer::ID = 0; +static RegisterPass<CFGViewer> +V0("view-cfg", "View CFG of function", false, true); + +namespace { + struct VISIBILITY_HIDDEN CFGOnlyViewer : public FunctionPass { + static char ID; // Pass identifcation, replacement for typeid + CFGOnlyViewer() : FunctionPass(&ID) {} + + virtual bool runOnFunction(Function &F) { + CFGOnly = true; + F.viewCFG(); + CFGOnly = false; + return false; + } + + void print(std::ostream &OS, const Module* = 0) const {} + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + }; +} + +char CFGOnlyViewer::ID = 0; +static RegisterPass<CFGOnlyViewer> +V1("view-cfg-only", + "View CFG of function (with no function bodies)", false, true); + +namespace { + struct VISIBILITY_HIDDEN CFGPrinter : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + CFGPrinter() : FunctionPass(&ID) {} + explicit CFGPrinter(void *pid) : FunctionPass(pid) {} + + virtual bool runOnFunction(Function &F) { + std::string Filename = "cfg." + F.getName() + ".dot"; + cerr << "Writing '" << Filename << "'..."; + std::ofstream File(Filename.c_str()); + + if (File.good()) + WriteGraph(File, (const Function*)&F); + else + cerr << " error opening file for writing!"; + cerr << "\n"; + return false; + } + + void print(std::ostream &OS, const Module* = 0) const {} + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + }; +} + +char CFGPrinter::ID = 0; +static RegisterPass<CFGPrinter> +P1("dot-cfg", "Print CFG of function to 'dot' file", false, true); + +namespace { + struct VISIBILITY_HIDDEN CFGOnlyPrinter : public CFGPrinter { + static char ID; // Pass identification, replacement for typeid + CFGOnlyPrinter() : CFGPrinter(&ID) {} + virtual bool runOnFunction(Function &F) { + bool OldCFGOnly = CFGOnly; + CFGOnly = true; + CFGPrinter::runOnFunction(F); + CFGOnly = OldCFGOnly; + return false; + } + void print(std::ostream &OS, const Module* = 0) const {} + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + }; +} + +char CFGOnlyPrinter::ID = 0; +static RegisterPass<CFGOnlyPrinter> +P2("dot-cfg-only", + "Print CFG of function to 'dot' file (with no function bodies)", false, true); + +/// viewCFG - This function is meant for use from the debugger. You can just +/// say 'call F->viewCFG()' and a ghostview window should pop up from the +/// program, displaying the CFG of the current function. This depends on there +/// being a 'dot' and 'gv' program in your path. +/// +void Function::viewCFG() const { + ViewGraph(this, "cfg" + getName()); +} + +/// viewCFGOnly - This function is meant for use from the debugger. It works +/// just like viewCFG, but it does not include the contents of basic blocks +/// into the nodes, just the label. If you are only interested in the CFG t +/// his can make the graph smaller. 
+/// +void Function::viewCFGOnly() const { + CFGOnly = true; + viewCFG(); + CFGOnly = false; +} + +FunctionPass *llvm::createCFGPrinterPass () { + return new CFGPrinter(); +} + +FunctionPass *llvm::createCFGOnlyPrinterPass () { + return new CFGOnlyPrinter(); +} + diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt new file mode 100644 index 0000000..093aa69 --- /dev/null +++ b/lib/Analysis/CMakeLists.txt @@ -0,0 +1,34 @@ +add_llvm_library(LLVMAnalysis + AliasAnalysis.cpp + AliasAnalysisCounter.cpp + AliasAnalysisEvaluator.cpp + AliasDebugger.cpp + AliasSetTracker.cpp + Analysis.cpp + BasicAliasAnalysis.cpp + CaptureTracking.cpp + CFGPrinter.cpp + ConstantFolding.cpp + DbgInfoPrinter.cpp + DebugInfo.cpp + InstCount.cpp + Interval.cpp + IntervalPartition.cpp + IVUsers.cpp + LibCallAliasAnalysis.cpp + LibCallSemantics.cpp + LiveValues.cpp + LoopInfo.cpp + LoopPass.cpp + LoopVR.cpp + MemoryDependenceAnalysis.cpp + PostDominators.cpp + ProfileInfo.cpp + ProfileInfoLoader.cpp + ProfileInfoLoaderPass.cpp + ScalarEvolution.cpp + ScalarEvolutionExpander.cpp + SparsePropagation.cpp + Trace.cpp + ValueTracking.cpp + ) diff --git a/lib/Analysis/CaptureTracking.cpp b/lib/Analysis/CaptureTracking.cpp new file mode 100644 index 0000000..a19b8e4 --- /dev/null +++ b/lib/Analysis/CaptureTracking.cpp @@ -0,0 +1,112 @@ +//===--- CaptureTracking.cpp - Determine whether a pointer is captured ----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains routines that help determine which pointers are captured. +// A pointer value is captured if the function makes a copy of any part of the +// pointer that outlives the call. Not being captured means, more or less, that +// the pointer is only dereferenced and not stored in a global. Returning part +// of the pointer as the function return value may or may not count as capturing +// the pointer, depending on the context. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Instructions.h" +#include "llvm/Value.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/CallSite.h" +using namespace llvm; + +/// PointerMayBeCaptured - Return true if this pointer value may be captured +/// by the enclosing function (which is required to exist). This routine can +/// be expensive, so consider caching the results. The boolean ReturnCaptures +/// specifies whether returning the value (or part of it) from the function +/// counts as capturing it or not. 
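+/// For example, storing the pointer into a global captures it, while merely
+/// loading through it, or passing it to a readonly, nounwind callee that
+/// returns void, does not.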
+bool llvm::PointerMayBeCaptured(const Value *V, bool ReturnCaptures) { + assert(isa<PointerType>(V->getType()) && "Capture is for pointers only!"); + SmallVector<Use*, 16> Worklist; + SmallSet<Use*, 16> Visited; + + for (Value::use_const_iterator UI = V->use_begin(), UE = V->use_end(); + UI != UE; ++UI) { + Use *U = &UI.getUse(); + Visited.insert(U); + Worklist.push_back(U); + } + + while (!Worklist.empty()) { + Use *U = Worklist.pop_back_val(); + Instruction *I = cast<Instruction>(U->getUser()); + V = U->get(); + + switch (I->getOpcode()) { + case Instruction::Call: + case Instruction::Invoke: { + CallSite CS = CallSite::get(I); + // Not captured if the callee is readonly, doesn't return a copy through + // its return value and doesn't unwind (a readonly function can leak bits + // by throwing an exception or not depending on the input value). + if (CS.onlyReadsMemory() && CS.doesNotThrow() && + I->getType() == Type::VoidTy) + break; + + // Not captured if only passed via 'nocapture' arguments. Note that + // calling a function pointer does not in itself cause the pointer to + // be captured. This is a subtle point considering that (for example) + // the callee might return its own address. It is analogous to saying + // that loading a value from a pointer does not cause the pointer to be + // captured, even though the loaded value might be the pointer itself + // (think of self-referential objects). + CallSite::arg_iterator B = CS.arg_begin(), E = CS.arg_end(); + for (CallSite::arg_iterator A = B; A != E; ++A) + if (A->get() == V && !CS.paramHasAttr(A - B + 1, Attribute::NoCapture)) + // The parameter is not marked 'nocapture' - captured. + return true; + // Only passed via 'nocapture' arguments, or is the called function - not + // captured. + break; + } + case Instruction::Free: + // Freeing a pointer does not cause it to be captured. + break; + case Instruction::Load: + // Loading from a pointer does not cause it to be captured. + break; + case Instruction::Ret: + if (ReturnCaptures) + return true; + break; + case Instruction::Store: + if (V == I->getOperand(0)) + // Stored the pointer - it may be captured. + return true; + // Storing to the pointee does not cause the pointer to be captured. + break; + case Instruction::BitCast: + case Instruction::GetElementPtr: + case Instruction::PHI: + case Instruction::Select: + // The original value is not captured via this if the new value isn't. + for (Instruction::use_iterator UI = I->use_begin(), UE = I->use_end(); + UI != UE; ++UI) { + Use *U = &UI.getUse(); + if (Visited.insert(U)) + Worklist.push_back(U); + } + break; + default: + // Something else - be conservative and say it is captured. + return true; + } + } + + // All uses examined - not captured. + return false; +} diff --git a/lib/Analysis/ConstantFolding.cpp b/lib/Analysis/ConstantFolding.cpp new file mode 100644 index 0000000..e5ab322 --- /dev/null +++ b/lib/Analysis/ConstantFolding.cpp @@ -0,0 +1,829 @@ +//===-- ConstantFolding.cpp - Analyze constant folding possibilities ------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This family of functions determines the possibility of performing constant +// folding. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Function.h" +#include "llvm/GlobalVariable.h" +#include "llvm/Instructions.h" +#include "llvm/Intrinsics.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Support/MathExtras.h" +#include <cerrno> +#include <cmath> +using namespace llvm; + +//===----------------------------------------------------------------------===// +// Constant Folding internal helper functions +//===----------------------------------------------------------------------===// + +/// IsConstantOffsetFromGlobal - If this constant is actually a constant offset +/// from a global, return the global and the constant. Because of +/// constantexprs, this function is recursive. +static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, + int64_t &Offset, const TargetData &TD) { + // Trivial case, constant is the global. + if ((GV = dyn_cast<GlobalValue>(C))) { + Offset = 0; + return true; + } + + // Otherwise, if this isn't a constant expr, bail out. + ConstantExpr *CE = dyn_cast<ConstantExpr>(C); + if (!CE) return false; + + // Look through ptr->int and ptr->ptr casts. + if (CE->getOpcode() == Instruction::PtrToInt || + CE->getOpcode() == Instruction::BitCast) + return IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, TD); + + // i32* getelementptr ([5 x i32]* @a, i32 0, i32 5) + if (CE->getOpcode() == Instruction::GetElementPtr) { + // Cannot compute this if the element type of the pointer is missing size + // info. + if (!cast<PointerType>(CE->getOperand(0)->getType()) + ->getElementType()->isSized()) + return false; + + // If the base isn't a global+constant, we aren't either. + if (!IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, TD)) + return false; + + // Otherwise, add any offset that our operands provide. + gep_type_iterator GTI = gep_type_begin(CE); + for (User::const_op_iterator i = CE->op_begin() + 1, e = CE->op_end(); + i != e; ++i, ++GTI) { + ConstantInt *CI = dyn_cast<ConstantInt>(*i); + if (!CI) return false; // Index isn't a simple constant? + if (CI->getZExtValue() == 0) continue; // Not adding anything. + + if (const StructType *ST = dyn_cast<StructType>(*GTI)) { + // N = N + Offset + Offset += TD.getStructLayout(ST)->getElementOffset(CI->getZExtValue()); + } else { + const SequentialType *SQT = cast<SequentialType>(*GTI); + Offset += TD.getTypeAllocSize(SQT->getElementType())*CI->getSExtValue(); + } + } + return true; + } + + return false; +} + + +/// SymbolicallyEvaluateBinop - One of Op0/Op1 is a constant expression. +/// Attempt to symbolically evaluate the result of a binary operator merging +/// these together. If target data info is available, it is provided as TD, +/// otherwise TD is null. +static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, + Constant *Op1, const TargetData *TD){ + // SROA + + // Fold (and 0xffffffff00000000, (shl x, 32)) -> shl. + // Fold (lshr (or X, Y), 32) -> (lshr [X/Y], 32) if one doesn't contribute + // bits. + + + // If the constant expr is something like &A[123] - &A[4].f, fold this into a + // constant. This happens frequently when iterating over a global array. 
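+ // Both operands must resolve to offsets from the same global for the
+ // subtraction to fold down to a plain integer distance.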
+ if (Opc == Instruction::Sub && TD) { + GlobalValue *GV1, *GV2; + int64_t Offs1, Offs2; + + if (IsConstantOffsetFromGlobal(Op0, GV1, Offs1, *TD)) + if (IsConstantOffsetFromGlobal(Op1, GV2, Offs2, *TD) && + GV1 == GV2) { + // (&GV+C1) - (&GV+C2) -> C1-C2, pointer arithmetic cannot overflow. + return ConstantInt::get(Op0->getType(), Offs1-Offs2); + } + } + + return 0; +} + +/// SymbolicallyEvaluateGEP - If we can symbolically evaluate the specified GEP +/// constant expression, do so. +static Constant *SymbolicallyEvaluateGEP(Constant* const* Ops, unsigned NumOps, + const Type *ResultTy, + const TargetData *TD) { + Constant *Ptr = Ops[0]; + if (!TD || !cast<PointerType>(Ptr->getType())->getElementType()->isSized()) + return 0; + + uint64_t BasePtr = 0; + if (!Ptr->isNullValue()) { + // If this is a inttoptr from a constant int, we can fold this as the base, + // otherwise we can't. + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) + if (CE->getOpcode() == Instruction::IntToPtr) + if (ConstantInt *Base = dyn_cast<ConstantInt>(CE->getOperand(0))) + BasePtr = Base->getZExtValue(); + + if (BasePtr == 0) + return 0; + } + + // If this is a constant expr gep that is effectively computing an + // "offsetof", fold it into 'cast int Size to T*' instead of 'gep 0, 0, 12' + for (unsigned i = 1; i != NumOps; ++i) + if (!isa<ConstantInt>(Ops[i])) + return false; + + uint64_t Offset = TD->getIndexedOffset(Ptr->getType(), + (Value**)Ops+1, NumOps-1); + Constant *C = ConstantInt::get(TD->getIntPtrType(), Offset+BasePtr); + return ConstantExpr::getIntToPtr(C, ResultTy); +} + +/// FoldBitCast - Constant fold bitcast, symbolically evaluating it with +/// targetdata. Return 0 if unfoldable. +static Constant *FoldBitCast(Constant *C, const Type *DestTy, + const TargetData &TD) { + // If this is a bitcast from constant vector -> vector, fold it. + if (ConstantVector *CV = dyn_cast<ConstantVector>(C)) { + if (const VectorType *DestVTy = dyn_cast<VectorType>(DestTy)) { + // If the element types match, VMCore can fold it. + unsigned NumDstElt = DestVTy->getNumElements(); + unsigned NumSrcElt = CV->getNumOperands(); + if (NumDstElt == NumSrcElt) + return 0; + + const Type *SrcEltTy = CV->getType()->getElementType(); + const Type *DstEltTy = DestVTy->getElementType(); + + // Otherwise, we're changing the number of elements in a vector, which + // requires endianness information to do the right thing. For example, + // bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>) + // folds to (little endian): + // <4 x i32> <i32 0, i32 0, i32 1, i32 0> + // and to (big endian): + // <4 x i32> <i32 0, i32 0, i32 0, i32 1> + + // First thing is first. We only want to think about integer here, so if + // we have something in FP form, recast it as integer. + if (DstEltTy->isFloatingPoint()) { + // Fold to an vector of integers with same size as our FP type. + unsigned FPWidth = DstEltTy->getPrimitiveSizeInBits(); + const Type *DestIVTy = VectorType::get(IntegerType::get(FPWidth), + NumDstElt); + // Recursively handle this integer conversion, if possible. + C = FoldBitCast(C, DestIVTy, TD); + if (!C) return 0; + + // Finally, VMCore can handle this now that #elts line up. + return ConstantExpr::getBitCast(C, DestTy); + } + + // Okay, we know the destination is integer, if the input is FP, convert + // it to integer first. 
+ if (SrcEltTy->isFloatingPoint()) { + unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits(); + const Type *SrcIVTy = VectorType::get(IntegerType::get(FPWidth), + NumSrcElt); + // Ask VMCore to do the conversion now that #elts line up. + C = ConstantExpr::getBitCast(C, SrcIVTy); + CV = dyn_cast<ConstantVector>(C); + if (!CV) return 0; // If VMCore wasn't able to fold it, bail out. + } + + // Now we know that the input and output vectors are both integer vectors + // of the same size, and that their #elements is not the same. Do the + // conversion here, which depends on whether the input or output has + // more elements. + bool isLittleEndian = TD.isLittleEndian(); + + SmallVector<Constant*, 32> Result; + if (NumDstElt < NumSrcElt) { + // Handle: bitcast (<4 x i32> <i32 0, i32 1, i32 2, i32 3> to <2 x i64>) + Constant *Zero = Constant::getNullValue(DstEltTy); + unsigned Ratio = NumSrcElt/NumDstElt; + unsigned SrcBitSize = SrcEltTy->getPrimitiveSizeInBits(); + unsigned SrcElt = 0; + for (unsigned i = 0; i != NumDstElt; ++i) { + // Build each element of the result. + Constant *Elt = Zero; + unsigned ShiftAmt = isLittleEndian ? 0 : SrcBitSize*(Ratio-1); + for (unsigned j = 0; j != Ratio; ++j) { + Constant *Src = dyn_cast<ConstantInt>(CV->getOperand(SrcElt++)); + if (!Src) return 0; // Reject constantexpr elements. + + // Zero extend the element to the right size. + Src = ConstantExpr::getZExt(Src, Elt->getType()); + + // Shift it to the right place, depending on endianness. + Src = ConstantExpr::getShl(Src, + ConstantInt::get(Src->getType(), ShiftAmt)); + ShiftAmt += isLittleEndian ? SrcBitSize : -SrcBitSize; + + // Mix it in. + Elt = ConstantExpr::getOr(Elt, Src); + } + Result.push_back(Elt); + } + } else { + // Handle: bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>) + unsigned Ratio = NumDstElt/NumSrcElt; + unsigned DstBitSize = DstEltTy->getPrimitiveSizeInBits(); + + // Loop over each source value, expanding into multiple results. + for (unsigned i = 0; i != NumSrcElt; ++i) { + Constant *Src = dyn_cast<ConstantInt>(CV->getOperand(i)); + if (!Src) return 0; // Reject constantexpr elements. + + unsigned ShiftAmt = isLittleEndian ? 0 : DstBitSize*(Ratio-1); + for (unsigned j = 0; j != Ratio; ++j) { + // Shift the piece of the value into the right place, depending on + // endianness. + Constant *Elt = ConstantExpr::getLShr(Src, + ConstantInt::get(Src->getType(), ShiftAmt)); + ShiftAmt += isLittleEndian ? DstBitSize : -DstBitSize; + + // Truncate and remember this piece. + Result.push_back(ConstantExpr::getTrunc(Elt, DstEltTy)); + } + } + } + + return ConstantVector::get(Result.data(), Result.size()); + } + } + + return 0; +} + + +//===----------------------------------------------------------------------===// +// Constant Folding public APIs +//===----------------------------------------------------------------------===// + + +/// ConstantFoldInstruction - Attempt to constant fold the specified +/// instruction. If successful, the constant result is returned, if not, null +/// is returned. Note that this function can only fail when attempting to fold +/// instructions like loads and stores, which have no constant expression form. +/// +Constant *llvm::ConstantFoldInstruction(Instruction *I, const TargetData *TD) { + if (PHINode *PN = dyn_cast<PHINode>(I)) { + if (PN->getNumIncomingValues() == 0) + return UndefValue::get(PN->getType()); + + Constant *Result = dyn_cast<Constant>(PN->getIncomingValue(0)); + if (Result == 0) return 0; + + // Handle PHI nodes specially here... 
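+ // A PHI folds only if every incoming value is the same constant (or the
+ // PHI itself); any other incoming value defeats the fold.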
+ for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) != Result && PN->getIncomingValue(i) != PN) + return 0; // Not all the same incoming constants... + + // If we reach here, all incoming values are the same constant. + return Result; + } + + // Scan the operand list, checking to see if they are all constants, if so, + // hand off to ConstantFoldInstOperands. + SmallVector<Constant*, 8> Ops; + for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i) + if (Constant *Op = dyn_cast<Constant>(*i)) + Ops.push_back(Op); + else + return 0; // All operands not constant! + + if (const CmpInst *CI = dyn_cast<CmpInst>(I)) + return ConstantFoldCompareInstOperands(CI->getPredicate(), + Ops.data(), Ops.size(), TD); + else + return ConstantFoldInstOperands(I->getOpcode(), I->getType(), + Ops.data(), Ops.size(), TD); +} + +/// ConstantFoldConstantExpression - Attempt to fold the constant expression +/// using the specified TargetData. If successful, the constant result is +/// result is returned, if not, null is returned. +Constant *llvm::ConstantFoldConstantExpression(ConstantExpr *CE, + const TargetData *TD) { + assert(TD && "ConstantFoldConstantExpression requires a valid TargetData."); + + SmallVector<Constant*, 8> Ops; + for (User::op_iterator i = CE->op_begin(), e = CE->op_end(); i != e; ++i) + Ops.push_back(cast<Constant>(*i)); + + if (CE->isCompare()) + return ConstantFoldCompareInstOperands(CE->getPredicate(), + Ops.data(), Ops.size(), TD); + else + return ConstantFoldInstOperands(CE->getOpcode(), CE->getType(), + Ops.data(), Ops.size(), TD); +} + +/// ConstantFoldInstOperands - Attempt to constant fold an instruction with the +/// specified opcode and operands. If successful, the constant result is +/// returned, if not, null is returned. Note that this function can fail when +/// attempting to fold instructions like loads and stores, which have no +/// constant expression form. +/// +Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, const Type *DestTy, + Constant* const* Ops, unsigned NumOps, + const TargetData *TD) { + // Handle easy binops first. + if (Instruction::isBinaryOp(Opcode)) { + if (isa<ConstantExpr>(Ops[0]) || isa<ConstantExpr>(Ops[1])) + if (Constant *C = SymbolicallyEvaluateBinop(Opcode, Ops[0], Ops[1], TD)) + return C; + + return ConstantExpr::get(Opcode, Ops[0], Ops[1]); + } + + switch (Opcode) { + default: return 0; + case Instruction::Call: + if (Function *F = dyn_cast<Function>(Ops[0])) + if (canConstantFoldCallTo(F)) + return ConstantFoldCall(F, Ops+1, NumOps-1); + return 0; + case Instruction::ICmp: + case Instruction::FCmp: + case Instruction::VICmp: + case Instruction::VFCmp: + assert(0 &&"This function is invalid for compares: no predicate specified"); + case Instruction::PtrToInt: + // If the input is a inttoptr, eliminate the pair. This requires knowing + // the width of a pointer, so it can't be done in ConstantExpr::getCast. + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) { + if (TD && CE->getOpcode() == Instruction::IntToPtr) { + Constant *Input = CE->getOperand(0); + unsigned InWidth = Input->getType()->getPrimitiveSizeInBits(); + if (TD->getPointerSizeInBits() < InWidth) { + Constant *Mask = + ConstantInt::get(APInt::getLowBitsSet(InWidth, + TD->getPointerSizeInBits())); + Input = ConstantExpr::getAnd(Input, Mask); + } + // Do a zext or trunc to get to the dest size. 
+ return ConstantExpr::getIntegerCast(Input, DestTy, false); + } + } + return ConstantExpr::getCast(Opcode, Ops[0], DestTy); + case Instruction::IntToPtr: + // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if + // the int size is >= the ptr size. This requires knowing the width of a + // pointer, so it can't be done in ConstantExpr::getCast. + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) { + if (TD && + TD->getPointerSizeInBits() <= + CE->getType()->getPrimitiveSizeInBits()) { + if (CE->getOpcode() == Instruction::PtrToInt) { + Constant *Input = CE->getOperand(0); + Constant *C = FoldBitCast(Input, DestTy, *TD); + return C ? C : ConstantExpr::getBitCast(Input, DestTy); + } + // If there's a constant offset added to the integer value before + // it is casted back to a pointer, see if the expression can be + // converted into a GEP. + if (CE->getOpcode() == Instruction::Add) + if (ConstantInt *L = dyn_cast<ConstantInt>(CE->getOperand(0))) + if (ConstantExpr *R = dyn_cast<ConstantExpr>(CE->getOperand(1))) + if (R->getOpcode() == Instruction::PtrToInt) + if (GlobalVariable *GV = + dyn_cast<GlobalVariable>(R->getOperand(0))) { + const PointerType *GVTy = cast<PointerType>(GV->getType()); + if (const ArrayType *AT = + dyn_cast<ArrayType>(GVTy->getElementType())) { + const Type *ElTy = AT->getElementType(); + uint64_t AllocSize = TD->getTypeAllocSize(ElTy); + APInt PSA(L->getValue().getBitWidth(), AllocSize); + if (ElTy == cast<PointerType>(DestTy)->getElementType() && + L->getValue().urem(PSA) == 0) { + APInt ElemIdx = L->getValue().udiv(PSA); + if (ElemIdx.ult(APInt(ElemIdx.getBitWidth(), + AT->getNumElements()))) { + Constant *Index[] = { + Constant::getNullValue(CE->getType()), + ConstantInt::get(ElemIdx) + }; + return ConstantExpr::getGetElementPtr(GV, &Index[0], 2); + } + } + } + } + } + } + return ConstantExpr::getCast(Opcode, Ops[0], DestTy); + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPToUI: + case Instruction::FPToSI: + return ConstantExpr::getCast(Opcode, Ops[0], DestTy); + case Instruction::BitCast: + if (TD) + if (Constant *C = FoldBitCast(Ops[0], DestTy, *TD)) + return C; + return ConstantExpr::getBitCast(Ops[0], DestTy); + case Instruction::Select: + return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]); + case Instruction::ExtractElement: + return ConstantExpr::getExtractElement(Ops[0], Ops[1]); + case Instruction::InsertElement: + return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]); + case Instruction::ShuffleVector: + return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]); + case Instruction::GetElementPtr: + if (Constant *C = SymbolicallyEvaluateGEP(Ops, NumOps, DestTy, TD)) + return C; + + return ConstantExpr::getGetElementPtr(Ops[0], Ops+1, NumOps-1); + } +} + +/// ConstantFoldCompareInstOperands - Attempt to constant fold a compare +/// instruction (icmp/fcmp) with the specified operands. If it fails, it +/// returns a constant expression of the specified operands. 
+/// +Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate, + Constant*const * Ops, + unsigned NumOps, + const TargetData *TD) { + // fold: icmp (inttoptr x), null -> icmp x, 0 + // fold: icmp (ptrtoint x), 0 -> icmp x, null + // fold: icmp (inttoptr x), (inttoptr y) -> icmp trunc/zext x, trunc/zext y + // fold: icmp (ptrtoint x), (ptrtoint y) -> icmp x, y + // + // ConstantExpr::getCompare cannot do this, because it doesn't have TD + // around to know if bit truncation is happening. + if (ConstantExpr *CE0 = dyn_cast<ConstantExpr>(Ops[0])) { + if (TD && Ops[1]->isNullValue()) { + const Type *IntPtrTy = TD->getIntPtrType(); + if (CE0->getOpcode() == Instruction::IntToPtr) { + // Convert the integer value to the right size to ensure we get the + // proper extension or truncation. + Constant *C = ConstantExpr::getIntegerCast(CE0->getOperand(0), + IntPtrTy, false); + Constant *NewOps[] = { C, Constant::getNullValue(C->getType()) }; + return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD); + } + + // Only do this transformation if the int is intptrty in size, otherwise + // there is a truncation or extension that we aren't modeling. + if (CE0->getOpcode() == Instruction::PtrToInt && + CE0->getType() == IntPtrTy) { + Constant *C = CE0->getOperand(0); + Constant *NewOps[] = { C, Constant::getNullValue(C->getType()) }; + // FIXME! + return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD); + } + } + + if (ConstantExpr *CE1 = dyn_cast<ConstantExpr>(Ops[1])) { + if (TD && CE0->getOpcode() == CE1->getOpcode()) { + const Type *IntPtrTy = TD->getIntPtrType(); + + if (CE0->getOpcode() == Instruction::IntToPtr) { + // Convert the integer value to the right size to ensure we get the + // proper extension or truncation. + Constant *C0 = ConstantExpr::getIntegerCast(CE0->getOperand(0), + IntPtrTy, false); + Constant *C1 = ConstantExpr::getIntegerCast(CE1->getOperand(0), + IntPtrTy, false); + Constant *NewOps[] = { C0, C1 }; + return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD); + } + + // Only do this transformation if the int is intptrty in size, otherwise + // there is a truncation or extension that we aren't modeling. + if ((CE0->getOpcode() == Instruction::PtrToInt && + CE0->getType() == IntPtrTy && + CE0->getOperand(0)->getType() == CE1->getOperand(0)->getType())) { + Constant *NewOps[] = { + CE0->getOperand(0), CE1->getOperand(0) + }; + return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD); + } + } + } + } + return ConstantExpr::getCompare(Predicate, Ops[0], Ops[1]); +} + + +/// ConstantFoldLoadThroughGEPConstantExpr - Given a constant and a +/// getelementptr constantexpr, return the constant value being addressed by the +/// constant expression, or null if something is funny and we can't decide. +Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C, + ConstantExpr *CE) { + if (CE->getOperand(1) != Constant::getNullValue(CE->getOperand(1)->getType())) + return 0; // Do not allow stepping over the value! + + // Loop over all of the operands, tracking down which value we are + // addressing... 
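+ // For example, given a global @G whose initializer is a struct containing
+ // an array, getelementptr (@G, 0, 1, 2) walks into field 1 of the struct
+ // and then into element 2 of the array found there.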
+ gep_type_iterator I = gep_type_begin(CE), E = gep_type_end(CE); + for (++I; I != E; ++I) + if (const StructType *STy = dyn_cast<StructType>(*I)) { + ConstantInt *CU = cast<ConstantInt>(I.getOperand()); + assert(CU->getZExtValue() < STy->getNumElements() && + "Struct index out of range!"); + unsigned El = (unsigned)CU->getZExtValue(); + if (ConstantStruct *CS = dyn_cast<ConstantStruct>(C)) { + C = CS->getOperand(El); + } else if (isa<ConstantAggregateZero>(C)) { + C = Constant::getNullValue(STy->getElementType(El)); + } else if (isa<UndefValue>(C)) { + C = UndefValue::get(STy->getElementType(El)); + } else { + return 0; + } + } else if (ConstantInt *CI = dyn_cast<ConstantInt>(I.getOperand())) { + if (const ArrayType *ATy = dyn_cast<ArrayType>(*I)) { + if (CI->getZExtValue() >= ATy->getNumElements()) + return 0; + if (ConstantArray *CA = dyn_cast<ConstantArray>(C)) + C = CA->getOperand(CI->getZExtValue()); + else if (isa<ConstantAggregateZero>(C)) + C = Constant::getNullValue(ATy->getElementType()); + else if (isa<UndefValue>(C)) + C = UndefValue::get(ATy->getElementType()); + else + return 0; + } else if (const VectorType *PTy = dyn_cast<VectorType>(*I)) { + if (CI->getZExtValue() >= PTy->getNumElements()) + return 0; + if (ConstantVector *CP = dyn_cast<ConstantVector>(C)) + C = CP->getOperand(CI->getZExtValue()); + else if (isa<ConstantAggregateZero>(C)) + C = Constant::getNullValue(PTy->getElementType()); + else if (isa<UndefValue>(C)) + C = UndefValue::get(PTy->getElementType()); + else + return 0; + } else { + return 0; + } + } else { + return 0; + } + return C; +} + + +//===----------------------------------------------------------------------===// +// Constant Folding for Calls +// + +/// canConstantFoldCallTo - Return true if its even possible to fold a call to +/// the specified function. +bool +llvm::canConstantFoldCallTo(const Function *F) { + switch (F->getIntrinsicID()) { + case Intrinsic::sqrt: + case Intrinsic::powi: + case Intrinsic::bswap: + case Intrinsic::ctpop: + case Intrinsic::ctlz: + case Intrinsic::cttz: + return true; + default: break; + } + + if (!F->hasName()) return false; + const char *Str = F->getNameStart(); + unsigned Len = F->getNameLen(); + + // In these cases, the check of the length is required. We don't want to + // return true for a name like "cos\0blah" which strcmp would return equal to + // "cos", but has length 8. 
+ switch (Str[0]) { + default: return false; + case 'a': + if (Len == 4) + return !strcmp(Str, "acos") || !strcmp(Str, "asin") || + !strcmp(Str, "atan"); + else if (Len == 5) + return !strcmp(Str, "atan2"); + return false; + case 'c': + if (Len == 3) + return !strcmp(Str, "cos"); + else if (Len == 4) + return !strcmp(Str, "ceil") || !strcmp(Str, "cosf") || + !strcmp(Str, "cosh"); + return false; + case 'e': + if (Len == 3) + return !strcmp(Str, "exp"); + return false; + case 'f': + if (Len == 4) + return !strcmp(Str, "fabs") || !strcmp(Str, "fmod"); + else if (Len == 5) + return !strcmp(Str, "floor"); + return false; + break; + case 'l': + if (Len == 3 && !strcmp(Str, "log")) + return true; + if (Len == 5 && !strcmp(Str, "log10")) + return true; + return false; + case 'p': + if (Len == 3 && !strcmp(Str, "pow")) + return true; + return false; + case 's': + if (Len == 3) + return !strcmp(Str, "sin"); + if (Len == 4) + return !strcmp(Str, "sinh") || !strcmp(Str, "sqrt") || + !strcmp(Str, "sinf"); + if (Len == 5) + return !strcmp(Str, "sqrtf"); + return false; + case 't': + if (Len == 3 && !strcmp(Str, "tan")) + return true; + else if (Len == 4 && !strcmp(Str, "tanh")) + return true; + return false; + } +} + +static Constant *ConstantFoldFP(double (*NativeFP)(double), double V, + const Type *Ty) { + errno = 0; + V = NativeFP(V); + if (errno != 0) { + errno = 0; + return 0; + } + + if (Ty == Type::FloatTy) + return ConstantFP::get(APFloat((float)V)); + if (Ty == Type::DoubleTy) + return ConstantFP::get(APFloat(V)); + assert(0 && "Can only constant fold float/double"); + return 0; // dummy return to suppress warning +} + +static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), + double V, double W, + const Type *Ty) { + errno = 0; + V = NativeFP(V, W); + if (errno != 0) { + errno = 0; + return 0; + } + + if (Ty == Type::FloatTy) + return ConstantFP::get(APFloat((float)V)); + if (Ty == Type::DoubleTy) + return ConstantFP::get(APFloat(V)); + assert(0 && "Can only constant fold float/double"); + return 0; // dummy return to suppress warning +} + +/// ConstantFoldCall - Attempt to constant fold a call to the specified function +/// with the specified arguments, returning null if unsuccessful. + +Constant * +llvm::ConstantFoldCall(Function *F, + Constant* const* Operands, unsigned NumOperands) { + if (!F->hasName()) return 0; + const char *Str = F->getNameStart(); + unsigned Len = F->getNameLen(); + + const Type *Ty = F->getReturnType(); + if (NumOperands == 1) { + if (ConstantFP *Op = dyn_cast<ConstantFP>(Operands[0])) { + if (Ty!=Type::FloatTy && Ty!=Type::DoubleTy) + return 0; + /// Currently APFloat versions of these functions do not exist, so we use + /// the host native double versions. Float versions are not called + /// directly but for all these it is true (float)(f((double)arg)) == + /// f(arg). Long double not supported yet. + double V = Ty==Type::FloatTy ? 
(double)Op->getValueAPF().convertToFloat(): + Op->getValueAPF().convertToDouble(); + switch (Str[0]) { + case 'a': + if (Len == 4 && !strcmp(Str, "acos")) + return ConstantFoldFP(acos, V, Ty); + else if (Len == 4 && !strcmp(Str, "asin")) + return ConstantFoldFP(asin, V, Ty); + else if (Len == 4 && !strcmp(Str, "atan")) + return ConstantFoldFP(atan, V, Ty); + break; + case 'c': + if (Len == 4 && !strcmp(Str, "ceil")) + return ConstantFoldFP(ceil, V, Ty); + else if (Len == 3 && !strcmp(Str, "cos")) + return ConstantFoldFP(cos, V, Ty); + else if (Len == 4 && !strcmp(Str, "cosh")) + return ConstantFoldFP(cosh, V, Ty); + else if (Len == 4 && !strcmp(Str, "cosf")) + return ConstantFoldFP(cos, V, Ty); + break; + case 'e': + if (Len == 3 && !strcmp(Str, "exp")) + return ConstantFoldFP(exp, V, Ty); + break; + case 'f': + if (Len == 4 && !strcmp(Str, "fabs")) + return ConstantFoldFP(fabs, V, Ty); + else if (Len == 5 && !strcmp(Str, "floor")) + return ConstantFoldFP(floor, V, Ty); + break; + case 'l': + if (Len == 3 && !strcmp(Str, "log") && V > 0) + return ConstantFoldFP(log, V, Ty); + else if (Len == 5 && !strcmp(Str, "log10") && V > 0) + return ConstantFoldFP(log10, V, Ty); + else if (!strcmp(Str, "llvm.sqrt.f32") || + !strcmp(Str, "llvm.sqrt.f64")) { + if (V >= -0.0) + return ConstantFoldFP(sqrt, V, Ty); + else // Undefined + return Constant::getNullValue(Ty); + } + break; + case 's': + if (Len == 3 && !strcmp(Str, "sin")) + return ConstantFoldFP(sin, V, Ty); + else if (Len == 4 && !strcmp(Str, "sinh")) + return ConstantFoldFP(sinh, V, Ty); + else if (Len == 4 && !strcmp(Str, "sqrt") && V >= 0) + return ConstantFoldFP(sqrt, V, Ty); + else if (Len == 5 && !strcmp(Str, "sqrtf") && V >= 0) + return ConstantFoldFP(sqrt, V, Ty); + else if (Len == 4 && !strcmp(Str, "sinf")) + return ConstantFoldFP(sin, V, Ty); + break; + case 't': + if (Len == 3 && !strcmp(Str, "tan")) + return ConstantFoldFP(tan, V, Ty); + else if (Len == 4 && !strcmp(Str, "tanh")) + return ConstantFoldFP(tanh, V, Ty); + break; + default: + break; + } + } else if (ConstantInt *Op = dyn_cast<ConstantInt>(Operands[0])) { + if (Len > 11 && !memcmp(Str, "llvm.bswap", 10)) + return ConstantInt::get(Op->getValue().byteSwap()); + else if (Len > 11 && !memcmp(Str, "llvm.ctpop", 10)) + return ConstantInt::get(Ty, Op->getValue().countPopulation()); + else if (Len > 10 && !memcmp(Str, "llvm.cttz", 9)) + return ConstantInt::get(Ty, Op->getValue().countTrailingZeros()); + else if (Len > 10 && !memcmp(Str, "llvm.ctlz", 9)) + return ConstantInt::get(Ty, Op->getValue().countLeadingZeros()); + } + } else if (NumOperands == 2) { + if (ConstantFP *Op1 = dyn_cast<ConstantFP>(Operands[0])) { + if (Ty!=Type::FloatTy && Ty!=Type::DoubleTy) + return 0; + double Op1V = Ty==Type::FloatTy ? + (double)Op1->getValueAPF().convertToFloat(): + Op1->getValueAPF().convertToDouble(); + if (ConstantFP *Op2 = dyn_cast<ConstantFP>(Operands[1])) { + double Op2V = Ty==Type::FloatTy ? 
+ (double)Op2->getValueAPF().convertToFloat(): + Op2->getValueAPF().convertToDouble(); + + if (Len == 3 && !strcmp(Str, "pow")) { + return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty); + } else if (Len == 4 && !strcmp(Str, "fmod")) { + return ConstantFoldBinaryFP(fmod, Op1V, Op2V, Ty); + } else if (Len == 5 && !strcmp(Str, "atan2")) { + return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty); + } + } else if (ConstantInt *Op2C = dyn_cast<ConstantInt>(Operands[1])) { + if (!strcmp(Str, "llvm.powi.f32")) { + return ConstantFP::get(APFloat((float)std::pow((float)Op1V, + (int)Op2C->getZExtValue()))); + } else if (!strcmp(Str, "llvm.powi.f64")) { + return ConstantFP::get(APFloat((double)std::pow((double)Op1V, + (int)Op2C->getZExtValue()))); + } + } + } + } + return 0; +} + diff --git a/lib/Analysis/DbgInfoPrinter.cpp b/lib/Analysis/DbgInfoPrinter.cpp new file mode 100644 index 0000000..d80d581 --- /dev/null +++ b/lib/Analysis/DbgInfoPrinter.cpp @@ -0,0 +1,167 @@ +//===- DbgInfoPrinter.cpp - Print debug info in a human readable form ------==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass that prints instructions, and associated debug +// info: +// +// - source/line/col information +// - original variable name +// - original type name +// +//===----------------------------------------------------------------------===// + +#include "llvm/Pass.h" +#include "llvm/Function.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Analysis/DebugInfo.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +static cl::opt<bool> +PrintDirectory("print-fullpath", + cl::desc("Print fullpath when printing debug info"), + cl::Hidden); + +namespace { + class VISIBILITY_HIDDEN PrintDbgInfo : public FunctionPass { + raw_ostream &Out; + void printStopPoint(const DbgStopPointInst *DSI); + void printFuncStart(const DbgFuncStartInst *FS); + void printVariableDeclaration(const Value *V); + public: + static char ID; // Pass identification + PrintDbgInfo() : FunctionPass(&ID), Out(outs()) {} + + virtual bool runOnFunction(Function &F); + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + }; + char PrintDbgInfo::ID = 0; + static RegisterPass<PrintDbgInfo> X("print-dbginfo", + "Print debug info in human readable form"); +} + +FunctionPass *llvm::createDbgInfoPrinterPass() { return new PrintDbgInfo(); } + +void PrintDbgInfo::printVariableDeclaration(const Value *V) { + std::string DisplayName, File, Directory, Type; + unsigned LineNo; + + if (!getLocationInfo(V, DisplayName, Type, LineNo, File, Directory)) + return; + + Out << "; "; + WriteAsOperand(Out, V, false, 0); + Out << " is variable " << DisplayName + << " of type " << Type << " declared at "; + + if (PrintDirectory) + Out << Directory << "/"; + + Out << File << ":" << LineNo << "\n"; +} + +void PrintDbgInfo::printStopPoint(const DbgStopPointInst *DSI) { + if (PrintDirectory) { + std::string dir; + GetConstantStringInfo(DSI->getDirectory(), dir); + Out << dir << "/"; + } + + std::string file; + GetConstantStringInfo(DSI->getFileName(), file); + Out << file << ":" << DSI->getLine(); + + if (unsigned Col = 
DSI->getColumn()) + Out << ":" << Col; +} + +void PrintDbgInfo::printFuncStart(const DbgFuncStartInst *FS) { + DISubprogram Subprogram(cast<GlobalVariable>(FS->getSubprogram())); + std::string Res1, Res2; + Out << "; fully qualified function name: " << Subprogram.getDisplayName(Res1) + << " return type: " << Subprogram.getType().getName(Res2) + << " at line " << Subprogram.getLineNumber() + << "\n\n"; +} + +bool PrintDbgInfo::runOnFunction(Function &F) { + if (F.isDeclaration()) + return false; + + Out << "function " << F.getName() << "\n\n"; + + for (Function::iterator I = F.begin(), E = F.end(); I != E; ++I) { + BasicBlock *BB = I; + + if (I != F.begin() && (pred_begin(BB) == pred_end(BB))) + // Skip dead blocks. + continue; + + const DbgStopPointInst *DSI = findBBStopPoint(BB); + Out << BB->getName(); + Out << ":"; + + if (DSI) { + Out << "; ("; + printStopPoint(DSI); + Out << ")"; + } + + Out << "\n"; + + // A dbgstoppoint's information is valid until we encounter a new one. + const DbgStopPointInst *LastDSP = DSI; + bool Printed = DSI != 0; + for (BasicBlock::const_iterator i = BB->begin(), e = BB->end(); + i != e; ++i) { + if (isa<DbgInfoIntrinsic>(i)) { + if ((DSI = dyn_cast<DbgStopPointInst>(i))) { + if (DSI->getContext() == LastDSP->getContext() && + DSI->getLineValue() == LastDSP->getLineValue() && + DSI->getColumnValue() == LastDSP->getColumnValue()) + // Don't print same location twice. + continue; + + LastDSP = cast<DbgStopPointInst>(i); + + // Don't print consecutive stoppoints, use a flag to know which one we + // printed. + Printed = false; + } else if (const DbgFuncStartInst *FS = dyn_cast<DbgFuncStartInst>(i)) { + printFuncStart(FS); + } + } else { + if (!Printed && LastDSP) { + Out << "; "; + printStopPoint(LastDSP); + Out << "\n"; + Printed = true; + } + + Out << *i; + printVariableDeclaration(i); + + if (const User *U = dyn_cast<User>(i)) { + for(unsigned i=0;i<U->getNumOperands();i++) + printVariableDeclaration(U->getOperand(i)); + } + } + } + } + + return false; +} diff --git a/lib/Analysis/DebugInfo.cpp b/lib/Analysis/DebugInfo.cpp new file mode 100644 index 0000000..6bdb64c --- /dev/null +++ b/lib/Analysis/DebugInfo.cpp @@ -0,0 +1,1079 @@ +//===--- DebugInfo.cpp - Debug Information Helper Classes -----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the helper classes used to build and interpret debug +// information in LLVM IR form. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DebugInfo.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Intrinsics.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Support/Dwarf.h" +#include "llvm/Support/Streams.h" + +using namespace llvm; +using namespace llvm::dwarf; + +//===----------------------------------------------------------------------===// +// DIDescriptor +//===----------------------------------------------------------------------===// + +/// ValidDebugInfo - Return true if V represents valid debug info value. 
+bool DIDescriptor::ValidDebugInfo(Value *V, CodeGenOpt::Level OptLevel) { + if (!V) + return false; + + GlobalVariable *GV = dyn_cast<GlobalVariable>(V->stripPointerCasts()); + if (!GV) + return false; + + if (!GV->hasInternalLinkage () && !GV->hasLinkOnceLinkage()) + return false; + + DIDescriptor DI(GV); + + // Check current version. Allow Version6 for now. + unsigned Version = DI.getVersion(); + if (Version != LLVMDebugVersion && Version != LLVMDebugVersion6) + return false; + + unsigned Tag = DI.getTag(); + switch (Tag) { + case DW_TAG_variable: + assert(DIVariable(GV).Verify() && "Invalid DebugInfo value"); + break; + case DW_TAG_compile_unit: + assert(DICompileUnit(GV).Verify() && "Invalid DebugInfo value"); + break; + case DW_TAG_subprogram: + assert(DISubprogram(GV).Verify() && "Invalid DebugInfo value"); + break; + case DW_TAG_lexical_block: + // FIXME: This interfers with the quality of generated code during + // optimization. + if (OptLevel != CodeGenOpt::None) + return false; + // FALLTHROUGH + default: + break; + } + + return true; +} + +DIDescriptor::DIDescriptor(GlobalVariable *gv, unsigned RequiredTag) { + GV = gv; + + // If this is non-null, check to see if the Tag matches. If not, set to null. + if (GV && getTag() != RequiredTag) + GV = 0; +} + +const std::string & +DIDescriptor::getStringField(unsigned Elt, std::string &Result) const { + if (GV == 0) { + Result.clear(); + return Result; + } + + Constant *C = GV->getInitializer(); + if (C == 0 || Elt >= C->getNumOperands()) { + Result.clear(); + return Result; + } + + // Fills in the string if it succeeds + if (!GetConstantStringInfo(C->getOperand(Elt), Result)) + Result.clear(); + + return Result; +} + +uint64_t DIDescriptor::getUInt64Field(unsigned Elt) const { + if (GV == 0) return 0; + + Constant *C = GV->getInitializer(); + if (C == 0 || Elt >= C->getNumOperands()) + return 0; + + if (ConstantInt *CI = dyn_cast<ConstantInt>(C->getOperand(Elt))) + return CI->getZExtValue(); + return 0; +} + +DIDescriptor DIDescriptor::getDescriptorField(unsigned Elt) const { + if (GV == 0) return DIDescriptor(); + + Constant *C = GV->getInitializer(); + if (C == 0 || Elt >= C->getNumOperands()) + return DIDescriptor(); + + C = C->getOperand(Elt); + return DIDescriptor(dyn_cast<GlobalVariable>(C->stripPointerCasts())); +} + +GlobalVariable *DIDescriptor::getGlobalVariableField(unsigned Elt) const { + if (GV == 0) return 0; + + Constant *C = GV->getInitializer(); + if (C == 0 || Elt >= C->getNumOperands()) + return 0; + + C = C->getOperand(Elt); + return dyn_cast<GlobalVariable>(C->stripPointerCasts()); +} + +//===----------------------------------------------------------------------===// +// Simple Descriptor Constructors and other Methods +//===----------------------------------------------------------------------===// + +// Needed by DIVariable::getType(). +DIType::DIType(GlobalVariable *gv) : DIDescriptor(gv) { + if (!gv) return; + unsigned tag = getTag(); + if (tag != dwarf::DW_TAG_base_type && !DIDerivedType::isDerivedType(tag) && + !DICompositeType::isCompositeType(tag)) + GV = 0; +} + +/// isDerivedType - Return true if the specified tag is legal for +/// DIDerivedType. 
+bool DIType::isDerivedType(unsigned Tag) { + switch (Tag) { + case dwarf::DW_TAG_typedef: + case dwarf::DW_TAG_pointer_type: + case dwarf::DW_TAG_reference_type: + case dwarf::DW_TAG_const_type: + case dwarf::DW_TAG_volatile_type: + case dwarf::DW_TAG_restrict_type: + case dwarf::DW_TAG_member: + case dwarf::DW_TAG_inheritance: + return true; + default: + // FIXME: Even though it doesn't make sense, CompositeTypes are current + // modelled as DerivedTypes, this should return true for them as well. + return false; + } +} + +/// isCompositeType - Return true if the specified tag is legal for +/// DICompositeType. +bool DIType::isCompositeType(unsigned TAG) { + switch (TAG) { + case dwarf::DW_TAG_array_type: + case dwarf::DW_TAG_structure_type: + case dwarf::DW_TAG_union_type: + case dwarf::DW_TAG_enumeration_type: + case dwarf::DW_TAG_vector_type: + case dwarf::DW_TAG_subroutine_type: + case dwarf::DW_TAG_class_type: + return true; + default: + return false; + } +} + +/// isVariable - Return true if the specified tag is legal for DIVariable. +bool DIVariable::isVariable(unsigned Tag) { + switch (Tag) { + case dwarf::DW_TAG_auto_variable: + case dwarf::DW_TAG_arg_variable: + case dwarf::DW_TAG_return_variable: + return true; + default: + return false; + } +} + +unsigned DIArray::getNumElements() const { + assert (GV && "Invalid DIArray"); + Constant *C = GV->getInitializer(); + assert (C && "Invalid DIArray initializer"); + return C->getNumOperands(); +} + +/// Verify - Verify that a compile unit is well formed. +bool DICompileUnit::Verify() const { + if (isNull()) + return false; + std::string Res; + if (getFilename(Res).empty()) + return false; + // It is possible that directory and produce string is empty. + return true; +} + +/// Verify - Verify that a type descriptor is well formed. +bool DIType::Verify() const { + if (isNull()) + return false; + if (getContext().isNull()) + return false; + + DICompileUnit CU = getCompileUnit(); + if (!CU.isNull() && !CU.Verify()) + return false; + return true; +} + +/// Verify - Verify that a composite type descriptor is well formed. +bool DICompositeType::Verify() const { + if (isNull()) + return false; + if (getContext().isNull()) + return false; + + DICompileUnit CU = getCompileUnit(); + if (!CU.isNull() && !CU.Verify()) + return false; + return true; +} + +/// Verify - Verify that a subprogram descriptor is well formed. +bool DISubprogram::Verify() const { + if (isNull()) + return false; + + if (getContext().isNull()) + return false; + + DICompileUnit CU = getCompileUnit(); + if (!CU.Verify()) + return false; + + DICompositeType Ty = getType(); + if (!Ty.isNull() && !Ty.Verify()) + return false; + return true; +} + +/// Verify - Verify that a global variable descriptor is well formed. +bool DIGlobalVariable::Verify() const { + if (isNull()) + return false; + + if (getContext().isNull()) + return false; + + DICompileUnit CU = getCompileUnit(); + if (!CU.isNull() && !CU.Verify()) + return false; + + DIType Ty = getType(); + if (!Ty.Verify()) + return false; + + if (!getGlobal()) + return false; + + return true; +} + +/// Verify - Verify that a variable descriptor is well formed. +bool DIVariable::Verify() const { + if (isNull()) + return false; + + if (getContext().isNull()) + return false; + + DIType Ty = getType(); + if (!Ty.Verify()) + return false; + + return true; +} + +/// getOriginalTypeSize - If this type is derived from a base type then +/// return base type size. 
+uint64_t DIDerivedType::getOriginalTypeSize() const { + if (getTag() != dwarf::DW_TAG_member) + return getSizeInBits(); + DIType BT = getTypeDerivedFrom(); + if (BT.getTag() != dwarf::DW_TAG_base_type) + return getSizeInBits(); + return BT.getSizeInBits(); +} + +/// describes - Return true if this subprogram provides debugging +/// information for the function F. +bool DISubprogram::describes(const Function *F) { + assert (F && "Invalid function"); + std::string Name; + getLinkageName(Name); + if (Name.empty()) + getName(Name); + if (!Name.empty() && (strcmp(Name.c_str(), F->getNameStart()) == false)) + return true; + return false; +} + +//===----------------------------------------------------------------------===// +// DIFactory: Basic Helpers +//===----------------------------------------------------------------------===// + +DIFactory::DIFactory(Module &m) + : M(m), StopPointFn(0), FuncStartFn(0), RegionStartFn(0), RegionEndFn(0), + DeclareFn(0) { + EmptyStructPtr = PointerType::getUnqual(StructType::get(NULL, NULL)); +} + +/// getCastToEmpty - Return this descriptor as a Constant* with type '{}*'. +/// This is only valid when the descriptor is non-null. +Constant *DIFactory::getCastToEmpty(DIDescriptor D) { + if (D.isNull()) return Constant::getNullValue(EmptyStructPtr); + return ConstantExpr::getBitCast(D.getGV(), EmptyStructPtr); +} + +Constant *DIFactory::GetTagConstant(unsigned TAG) { + assert((TAG & LLVMDebugVersionMask) == 0 && + "Tag too large for debug encoding!"); + return ConstantInt::get(Type::Int32Ty, TAG | LLVMDebugVersion); +} + +Constant *DIFactory::GetStringConstant(const std::string &String) { + // Check string cache for previous edition. + Constant *&Slot = StringCache[String]; + + // Return Constant if previously defined. + if (Slot) return Slot; + + const PointerType *DestTy = PointerType::getUnqual(Type::Int8Ty); + + // If empty string then use a sbyte* null instead. + if (String.empty()) + return Slot = ConstantPointerNull::get(DestTy); + + // Construct string as an llvm constant. + Constant *ConstStr = ConstantArray::get(String); + + // Otherwise create and return a new string global. + GlobalVariable *StrGV = new GlobalVariable(ConstStr->getType(), true, + GlobalVariable::InternalLinkage, + ConstStr, ".str", &M); + StrGV->setSection("llvm.metadata"); + return Slot = ConstantExpr::getBitCast(StrGV, DestTy); +} + +/// GetOrCreateAnchor - Look up an anchor for the specified tag and name. If it +/// already exists, return it. If not, create a new one and return it. +DIAnchor DIFactory::GetOrCreateAnchor(unsigned TAG, const char *Name) { + const Type *EltTy = StructType::get(Type::Int32Ty, Type::Int32Ty, NULL); + + // Otherwise, create the global or return it if already in the module. + Constant *C = M.getOrInsertGlobal(Name, EltTy); + assert(isa<GlobalVariable>(C) && "Incorrectly typed anchor?"); + GlobalVariable *GV = cast<GlobalVariable>(C); + + // If it has an initializer, it is already in the module. + if (GV->hasInitializer()) + return SubProgramAnchor = DIAnchor(GV); + + GV->setLinkage(GlobalValue::LinkOnceAnyLinkage); + GV->setSection("llvm.metadata"); + GV->setConstant(true); + M.addTypeName("llvm.dbg.anchor.type", EltTy); + + // Otherwise, set the initializer. 
+ Constant *Elts[] = { + GetTagConstant(dwarf::DW_TAG_anchor), + ConstantInt::get(Type::Int32Ty, TAG) + }; + + GV->setInitializer(ConstantStruct::get(Elts, 2)); + return DIAnchor(GV); +} + + + +//===----------------------------------------------------------------------===// +// DIFactory: Primary Constructors +//===----------------------------------------------------------------------===// + +/// GetOrCreateCompileUnitAnchor - Return the anchor for compile units, +/// creating a new one if there isn't already one in the module. +DIAnchor DIFactory::GetOrCreateCompileUnitAnchor() { + // If we already created one, just return it. + if (!CompileUnitAnchor.isNull()) + return CompileUnitAnchor; + return CompileUnitAnchor = GetOrCreateAnchor(dwarf::DW_TAG_compile_unit, + "llvm.dbg.compile_units"); +} + +/// GetOrCreateSubprogramAnchor - Return the anchor for subprograms, +/// creating a new one if there isn't already one in the module. +DIAnchor DIFactory::GetOrCreateSubprogramAnchor() { + // If we already created one, just return it. + if (!SubProgramAnchor.isNull()) + return SubProgramAnchor; + return SubProgramAnchor = GetOrCreateAnchor(dwarf::DW_TAG_subprogram, + "llvm.dbg.subprograms"); +} + +/// GetOrCreateGlobalVariableAnchor - Return the anchor for globals, +/// creating a new one if there isn't already one in the module. +DIAnchor DIFactory::GetOrCreateGlobalVariableAnchor() { + // If we already created one, just return it. + if (!GlobalVariableAnchor.isNull()) + return GlobalVariableAnchor; + return GlobalVariableAnchor = GetOrCreateAnchor(dwarf::DW_TAG_variable, + "llvm.dbg.global_variables"); +} + +/// GetOrCreateArray - Create an descriptor for an array of descriptors. +/// This implicitly uniques the arrays created. +DIArray DIFactory::GetOrCreateArray(DIDescriptor *Tys, unsigned NumTys) { + SmallVector<Constant*, 16> Elts; + + for (unsigned i = 0; i != NumTys; ++i) + Elts.push_back(getCastToEmpty(Tys[i])); + + Constant *Init = ConstantArray::get(ArrayType::get(EmptyStructPtr, + Elts.size()), + Elts.data(), Elts.size()); + // If we already have this array, just return the uniqued version. + DIDescriptor &Entry = SimpleConstantCache[Init]; + if (!Entry.isNull()) return DIArray(Entry.getGV()); + + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.array", &M); + GV->setSection("llvm.metadata"); + Entry = DIDescriptor(GV); + return DIArray(GV); +} + +/// GetOrCreateSubrange - Create a descriptor for a value range. This +/// implicitly uniques the values returned. +DISubrange DIFactory::GetOrCreateSubrange(int64_t Lo, int64_t Hi) { + Constant *Elts[] = { + GetTagConstant(dwarf::DW_TAG_subrange_type), + ConstantInt::get(Type::Int64Ty, Lo), + ConstantInt::get(Type::Int64Ty, Hi) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + // If we already have this range, just return the uniqued version. + DIDescriptor &Entry = SimpleConstantCache[Init]; + if (!Entry.isNull()) return DISubrange(Entry.getGV()); + + M.addTypeName("llvm.dbg.subrange.type", Init->getType()); + + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.subrange", &M); + GV->setSection("llvm.metadata"); + Entry = DIDescriptor(GV); + return DISubrange(GV); +} + + + +/// CreateCompileUnit - Create a new descriptor for the specified compile +/// unit. Note that this does not unique compile units within the module. 
+DICompileUnit DIFactory::CreateCompileUnit(unsigned LangID, + const std::string &Filename, + const std::string &Directory, + const std::string &Producer, + bool isMain, + bool isOptimized, + const char *Flags, + unsigned RunTimeVer) { + Constant *Elts[] = { + GetTagConstant(dwarf::DW_TAG_compile_unit), + getCastToEmpty(GetOrCreateCompileUnitAnchor()), + ConstantInt::get(Type::Int32Ty, LangID), + GetStringConstant(Filename), + GetStringConstant(Directory), + GetStringConstant(Producer), + ConstantInt::get(Type::Int1Ty, isMain), + ConstantInt::get(Type::Int1Ty, isOptimized), + GetStringConstant(Flags), + ConstantInt::get(Type::Int32Ty, RunTimeVer) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + M.addTypeName("llvm.dbg.compile_unit.type", Init->getType()); + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.compile_unit", &M); + GV->setSection("llvm.metadata"); + return DICompileUnit(GV); +} + +/// CreateEnumerator - Create a single enumerator value. +DIEnumerator DIFactory::CreateEnumerator(const std::string &Name, uint64_t Val){ + Constant *Elts[] = { + GetTagConstant(dwarf::DW_TAG_enumerator), + GetStringConstant(Name), + ConstantInt::get(Type::Int64Ty, Val) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + M.addTypeName("llvm.dbg.enumerator.type", Init->getType()); + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.enumerator", &M); + GV->setSection("llvm.metadata"); + return DIEnumerator(GV); +} + + +/// CreateBasicType - Create a basic type like int, float, etc. +DIBasicType DIFactory::CreateBasicType(DIDescriptor Context, + const std::string &Name, + DICompileUnit CompileUnit, + unsigned LineNumber, + uint64_t SizeInBits, + uint64_t AlignInBits, + uint64_t OffsetInBits, unsigned Flags, + unsigned Encoding) { + Constant *Elts[] = { + GetTagConstant(dwarf::DW_TAG_base_type), + getCastToEmpty(Context), + GetStringConstant(Name), + getCastToEmpty(CompileUnit), + ConstantInt::get(Type::Int32Ty, LineNumber), + ConstantInt::get(Type::Int64Ty, SizeInBits), + ConstantInt::get(Type::Int64Ty, AlignInBits), + ConstantInt::get(Type::Int64Ty, OffsetInBits), + ConstantInt::get(Type::Int32Ty, Flags), + ConstantInt::get(Type::Int32Ty, Encoding) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + M.addTypeName("llvm.dbg.basictype.type", Init->getType()); + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.basictype", &M); + GV->setSection("llvm.metadata"); + return DIBasicType(GV); +} + +/// CreateDerivedType - Create a derived type like const qualified type, +/// pointer, typedef, etc. 
+DIDerivedType DIFactory::CreateDerivedType(unsigned Tag, + DIDescriptor Context, + const std::string &Name, + DICompileUnit CompileUnit, + unsigned LineNumber, + uint64_t SizeInBits, + uint64_t AlignInBits, + uint64_t OffsetInBits, + unsigned Flags, + DIType DerivedFrom) { + Constant *Elts[] = { + GetTagConstant(Tag), + getCastToEmpty(Context), + GetStringConstant(Name), + getCastToEmpty(CompileUnit), + ConstantInt::get(Type::Int32Ty, LineNumber), + ConstantInt::get(Type::Int64Ty, SizeInBits), + ConstantInt::get(Type::Int64Ty, AlignInBits), + ConstantInt::get(Type::Int64Ty, OffsetInBits), + ConstantInt::get(Type::Int32Ty, Flags), + getCastToEmpty(DerivedFrom) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + M.addTypeName("llvm.dbg.derivedtype.type", Init->getType()); + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.derivedtype", &M); + GV->setSection("llvm.metadata"); + return DIDerivedType(GV); +} + +/// CreateCompositeType - Create a composite type like array, struct, etc. +DICompositeType DIFactory::CreateCompositeType(unsigned Tag, + DIDescriptor Context, + const std::string &Name, + DICompileUnit CompileUnit, + unsigned LineNumber, + uint64_t SizeInBits, + uint64_t AlignInBits, + uint64_t OffsetInBits, + unsigned Flags, + DIType DerivedFrom, + DIArray Elements, + unsigned RuntimeLang) { + + Constant *Elts[] = { + GetTagConstant(Tag), + getCastToEmpty(Context), + GetStringConstant(Name), + getCastToEmpty(CompileUnit), + ConstantInt::get(Type::Int32Ty, LineNumber), + ConstantInt::get(Type::Int64Ty, SizeInBits), + ConstantInt::get(Type::Int64Ty, AlignInBits), + ConstantInt::get(Type::Int64Ty, OffsetInBits), + ConstantInt::get(Type::Int32Ty, Flags), + getCastToEmpty(DerivedFrom), + getCastToEmpty(Elements), + ConstantInt::get(Type::Int32Ty, RuntimeLang) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + M.addTypeName("llvm.dbg.composite.type", Init->getType()); + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.composite", &M); + GV->setSection("llvm.metadata"); + return DICompositeType(GV); +} + + +/// CreateSubprogram - Create a new descriptor for the specified subprogram. +/// See comments in DISubprogram for descriptions of these fields. This +/// method does not unique the generated descriptors. 
+DISubprogram DIFactory::CreateSubprogram(DIDescriptor Context, + const std::string &Name, + const std::string &DisplayName, + const std::string &LinkageName, + DICompileUnit CompileUnit, + unsigned LineNo, DIType Type, + bool isLocalToUnit, + bool isDefinition) { + + Constant *Elts[] = { + GetTagConstant(dwarf::DW_TAG_subprogram), + getCastToEmpty(GetOrCreateSubprogramAnchor()), + getCastToEmpty(Context), + GetStringConstant(Name), + GetStringConstant(DisplayName), + GetStringConstant(LinkageName), + getCastToEmpty(CompileUnit), + ConstantInt::get(Type::Int32Ty, LineNo), + getCastToEmpty(Type), + ConstantInt::get(Type::Int1Ty, isLocalToUnit), + ConstantInt::get(Type::Int1Ty, isDefinition) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + M.addTypeName("llvm.dbg.subprogram.type", Init->getType()); + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.subprogram", &M); + GV->setSection("llvm.metadata"); + return DISubprogram(GV); +} + +/// CreateGlobalVariable - Create a new descriptor for the specified global. +DIGlobalVariable +DIFactory::CreateGlobalVariable(DIDescriptor Context, const std::string &Name, + const std::string &DisplayName, + const std::string &LinkageName, + DICompileUnit CompileUnit, + unsigned LineNo, DIType Type,bool isLocalToUnit, + bool isDefinition, llvm::GlobalVariable *Val) { + Constant *Elts[] = { + GetTagConstant(dwarf::DW_TAG_variable), + getCastToEmpty(GetOrCreateGlobalVariableAnchor()), + getCastToEmpty(Context), + GetStringConstant(Name), + GetStringConstant(DisplayName), + GetStringConstant(LinkageName), + getCastToEmpty(CompileUnit), + ConstantInt::get(Type::Int32Ty, LineNo), + getCastToEmpty(Type), + ConstantInt::get(Type::Int1Ty, isLocalToUnit), + ConstantInt::get(Type::Int1Ty, isDefinition), + ConstantExpr::getBitCast(Val, EmptyStructPtr) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + M.addTypeName("llvm.dbg.global_variable.type", Init->getType()); + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.global_variable", &M); + GV->setSection("llvm.metadata"); + return DIGlobalVariable(GV); +} + + +/// CreateVariable - Create a new descriptor for the specified variable. +DIVariable DIFactory::CreateVariable(unsigned Tag, DIDescriptor Context, + const std::string &Name, + DICompileUnit CompileUnit, unsigned LineNo, + DIType Type) { + Constant *Elts[] = { + GetTagConstant(Tag), + getCastToEmpty(Context), + GetStringConstant(Name), + getCastToEmpty(CompileUnit), + ConstantInt::get(Type::Int32Ty, LineNo), + getCastToEmpty(Type) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + M.addTypeName("llvm.dbg.variable.type", Init->getType()); + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.variable", &M); + GV->setSection("llvm.metadata"); + return DIVariable(GV); +} + + +/// CreateBlock - This creates a descriptor for a lexical block with the +/// specified parent context. 
+DIBlock DIFactory::CreateBlock(DIDescriptor Context) { + Constant *Elts[] = { + GetTagConstant(dwarf::DW_TAG_lexical_block), + getCastToEmpty(Context) + }; + + Constant *Init = ConstantStruct::get(Elts, sizeof(Elts)/sizeof(Elts[0])); + + M.addTypeName("llvm.dbg.block.type", Init->getType()); + GlobalVariable *GV = new GlobalVariable(Init->getType(), true, + GlobalValue::InternalLinkage, + Init, "llvm.dbg.block", &M); + GV->setSection("llvm.metadata"); + return DIBlock(GV); +} + + +//===----------------------------------------------------------------------===// +// DIFactory: Routines for inserting code into a function +//===----------------------------------------------------------------------===// + +/// InsertStopPoint - Create a new llvm.dbg.stoppoint intrinsic invocation, +/// inserting it at the end of the specified basic block. +void DIFactory::InsertStopPoint(DICompileUnit CU, unsigned LineNo, + unsigned ColNo, BasicBlock *BB) { + + // Lazily construct llvm.dbg.stoppoint function. + if (!StopPointFn) + StopPointFn = llvm::Intrinsic::getDeclaration(&M, + llvm::Intrinsic::dbg_stoppoint); + + // Invoke llvm.dbg.stoppoint + Value *Args[] = { + llvm::ConstantInt::get(llvm::Type::Int32Ty, LineNo), + llvm::ConstantInt::get(llvm::Type::Int32Ty, ColNo), + getCastToEmpty(CU) + }; + CallInst::Create(StopPointFn, Args, Args+3, "", BB); +} + +/// InsertSubprogramStart - Create a new llvm.dbg.func.start intrinsic to +/// mark the start of the specified subprogram. +void DIFactory::InsertSubprogramStart(DISubprogram SP, BasicBlock *BB) { + // Lazily construct llvm.dbg.func.start. + if (!FuncStartFn) + FuncStartFn = Intrinsic::getDeclaration(&M, Intrinsic::dbg_func_start); + + // Call llvm.dbg.func.start which also implicitly sets a stoppoint. + CallInst::Create(FuncStartFn, getCastToEmpty(SP), "", BB); +} + +/// InsertRegionStart - Insert a new llvm.dbg.region.start intrinsic call to +/// mark the start of a region for the specified scoping descriptor. +void DIFactory::InsertRegionStart(DIDescriptor D, BasicBlock *BB) { + // Lazily construct llvm.dbg.region.start function. + if (!RegionStartFn) + RegionStartFn = Intrinsic::getDeclaration(&M, Intrinsic::dbg_region_start); + + // Call llvm.dbg.func.start. + CallInst::Create(RegionStartFn, getCastToEmpty(D), "", BB); +} + +/// InsertRegionEnd - Insert a new llvm.dbg.region.end intrinsic call to +/// mark the end of a region for the specified scoping descriptor. +void DIFactory::InsertRegionEnd(DIDescriptor D, BasicBlock *BB) { + // Lazily construct llvm.dbg.region.end function. + if (!RegionEndFn) + RegionEndFn = Intrinsic::getDeclaration(&M, Intrinsic::dbg_region_end); + + // Call llvm.dbg.region.end. + CallInst::Create(RegionEndFn, getCastToEmpty(D), "", BB); +} + +/// InsertDeclare - Insert a new llvm.dbg.declare intrinsic call. +void DIFactory::InsertDeclare(Value *Storage, DIVariable D, BasicBlock *BB) { + // Cast the storage to a {}* for the call to llvm.dbg.declare. + Storage = new BitCastInst(Storage, EmptyStructPtr, "", BB); + + if (!DeclareFn) + DeclareFn = Intrinsic::getDeclaration(&M, Intrinsic::dbg_declare); + + Value *Args[] = { Storage, getCastToEmpty(D) }; + CallInst::Create(DeclareFn, Args, Args+2, "", BB); +} + +namespace llvm { + /// findStopPoint - Find the stoppoint coressponding to this instruction, that + /// is the stoppoint that dominates this instruction. 
+ const DbgStopPointInst *findStopPoint(const Instruction *Inst) { + if (const DbgStopPointInst *DSI = dyn_cast<DbgStopPointInst>(Inst)) + return DSI; + + const BasicBlock *BB = Inst->getParent(); + BasicBlock::const_iterator I = Inst, B; + while (BB) { + B = BB->begin(); + + // A BB consisting only of a terminator can't have a stoppoint. + while (I != B) { + --I; + if (const DbgStopPointInst *DSI = dyn_cast<DbgStopPointInst>(I)) + return DSI; + } + + // This BB didn't have a stoppoint: if there is only one predecessor, look + // for a stoppoint there. We could use getIDom(), but that would require + // dominator info. + BB = I->getParent()->getUniquePredecessor(); + if (BB) + I = BB->getTerminator(); + } + + return 0; + } + + /// findBBStopPoint - Find the stoppoint corresponding to first real + /// (non-debug intrinsic) instruction in this Basic Block, and return the + /// stoppoint for it. + const DbgStopPointInst *findBBStopPoint(const BasicBlock *BB) { + for(BasicBlock::const_iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (const DbgStopPointInst *DSI = dyn_cast<DbgStopPointInst>(I)) + return DSI; + + // Fallback to looking for stoppoint of unique predecessor. Useful if this + // BB contains no stoppoints, but unique predecessor does. + BB = BB->getUniquePredecessor(); + if (BB) + return findStopPoint(BB->getTerminator()); + + return 0; + } + + Value *findDbgGlobalDeclare(GlobalVariable *V) { + const Module *M = V->getParent(); + const Type *Ty = M->getTypeByName("llvm.dbg.global_variable.type"); + if (!Ty) return 0; + + Ty = PointerType::get(Ty, 0); + + Value *Val = V->stripPointerCasts(); + for (Value::use_iterator I = Val->use_begin(), E = Val->use_end(); + I != E; ++I) { + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(I)) { + if (CE->getOpcode() == Instruction::BitCast) { + Value *VV = CE; + + while (VV->hasOneUse()) + VV = *VV->use_begin(); + + if (VV->getType() == Ty) + return VV; + } + } + } + + if (Val->getType() == Ty) + return Val; + + return 0; + } + + /// Finds the llvm.dbg.declare intrinsic corresponding to this value if any. + /// It looks through pointer casts too. + const DbgDeclareInst *findDbgDeclare(const Value *V, bool stripCasts) { + if (stripCasts) { + V = V->stripPointerCasts(); + + // Look for the bitcast. + for (Value::use_const_iterator I = V->use_begin(), E =V->use_end(); + I != E; ++I) + if (isa<BitCastInst>(I)) + return findDbgDeclare(*I, false); + + return 0; + } + + // Find llvm.dbg.declare among uses of the instruction. 
+ for (Value::use_const_iterator I = V->use_begin(), E =V->use_end(); + I != E; ++I) + if (const DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(I)) + return DDI; + + return 0; + } + + bool getLocationInfo(const Value *V, std::string &DisplayName, + std::string &Type, unsigned &LineNo, std::string &File, + std::string &Dir) { + DICompileUnit Unit; + DIType TypeD; + + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(const_cast<Value*>(V))) { + Value *DIGV = findDbgGlobalDeclare(GV); + if (!DIGV) return false; + DIGlobalVariable Var(cast<GlobalVariable>(DIGV)); + + Var.getDisplayName(DisplayName); + LineNo = Var.getLineNumber(); + Unit = Var.getCompileUnit(); + TypeD = Var.getType(); + } else { + const DbgDeclareInst *DDI = findDbgDeclare(V); + if (!DDI) return false; + DIVariable Var(cast<GlobalVariable>(DDI->getVariable())); + + Var.getName(DisplayName); + LineNo = Var.getLineNumber(); + Unit = Var.getCompileUnit(); + TypeD = Var.getType(); + } + + TypeD.getName(Type); + Unit.getFilename(File); + Unit.getDirectory(Dir); + return true; + } +} + +/// dump - Print descriptor. +void DIDescriptor::dump() const { + cerr << "[" << dwarf::TagString(getTag()) << "] "; + cerr << std::hex << "[GV:" << GV << "]" << std::dec; +} + +/// dump - Print compile unit. +void DICompileUnit::dump() const { + if (getLanguage()) + cerr << " [" << dwarf::LanguageString(getLanguage()) << "] "; + + std::string Res1, Res2; + cerr << " [" << getDirectory(Res1) << "/" << getFilename(Res2) << " ]"; +} + +/// dump - Print type. +void DIType::dump() const { + if (isNull()) return; + + std::string Res; + if (!getName(Res).empty()) + cerr << " [" << Res << "] "; + + unsigned Tag = getTag(); + cerr << " [" << dwarf::TagString(Tag) << "] "; + + // TODO : Print context + getCompileUnit().dump(); + cerr << " [" + << getLineNumber() << ", " + << getSizeInBits() << ", " + << getAlignInBits() << ", " + << getOffsetInBits() + << "] "; + + if (isPrivate()) + cerr << " [private] "; + else if (isProtected()) + cerr << " [protected] "; + + if (isForwardDecl()) + cerr << " [fwd] "; + + if (isBasicType(Tag)) + DIBasicType(GV).dump(); + else if (isDerivedType(Tag)) + DIDerivedType(GV).dump(); + else if (isCompositeType(Tag)) + DICompositeType(GV).dump(); + else { + cerr << "Invalid DIType\n"; + return; + } + + cerr << "\n"; +} + +/// dump - Print basic type. +void DIBasicType::dump() const { + cerr << " [" << dwarf::AttributeEncodingString(getEncoding()) << "] "; +} + +/// dump - Print derived type. +void DIDerivedType::dump() const { + cerr << "\n\t Derived From: "; getTypeDerivedFrom().dump(); +} + +/// dump - Print composite type. +void DICompositeType::dump() const { + DIArray A = getTypeArray(); + if (A.isNull()) + return; + cerr << " [" << A.getNumElements() << " elements]"; +} + +/// dump - Print global. +void DIGlobal::dump() const { + std::string Res; + if (!getName(Res).empty()) + cerr << " [" << Res << "] "; + + unsigned Tag = getTag(); + cerr << " [" << dwarf::TagString(Tag) << "] "; + + // TODO : Print context + getCompileUnit().dump(); + cerr << " [" << getLineNumber() << "] "; + + if (isLocalToUnit()) + cerr << " [local] "; + + if (isDefinition()) + cerr << " [def] "; + + if (isGlobalVariable(Tag)) + DIGlobalVariable(GV).dump(); + + cerr << "\n"; +} + +/// dump - Print subprogram. +void DISubprogram::dump() const { + DIGlobal::dump(); +} + +/// dump - Print global variable. +void DIGlobalVariable::dump() const { + cerr << " ["; getGlobal()->dump(); cerr << "] "; +} + +/// dump - Print variable. 
+void DIVariable::dump() const { + std::string Res; + if (!getName(Res).empty()) + cerr << " [" << Res << "] "; + + getCompileUnit().dump(); + cerr << " [" << getLineNumber() << "] "; + getType().dump(); + cerr << "\n"; +} diff --git a/lib/Analysis/IPA/Andersens.cpp b/lib/Analysis/IPA/Andersens.cpp new file mode 100644 index 0000000..8584d06 --- /dev/null +++ b/lib/Analysis/IPA/Andersens.cpp @@ -0,0 +1,2878 @@ +//===- Andersens.cpp - Andersen's Interprocedural Alias Analysis ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines an implementation of Andersen's interprocedural alias +// analysis +// +// In pointer analysis terms, this is a subset-based, flow-insensitive, +// field-sensitive, and context-insensitive algorithm pointer algorithm. +// +// This algorithm is implemented as three stages: +// 1. Object identification. +// 2. Inclusion constraint identification. +// 3. Offline constraint graph optimization +// 4. Inclusion constraint solving. +// +// The object identification stage identifies all of the memory objects in the +// program, which includes globals, heap allocated objects, and stack allocated +// objects. +// +// The inclusion constraint identification stage finds all inclusion constraints +// in the program by scanning the program, looking for pointer assignments and +// other statements that effect the points-to graph. For a statement like "A = +// B", this statement is processed to indicate that A can point to anything that +// B can point to. Constraints can handle copies, loads, and stores, and +// address taking. +// +// The offline constraint graph optimization portion includes offline variable +// substitution algorithms intended to compute pointer and location +// equivalences. Pointer equivalences are those pointers that will have the +// same points-to sets, and location equivalences are those variables that +// always appear together in points-to sets. It also includes an offline +// cycle detection algorithm that allows cycles to be collapsed sooner +// during solving. +// +// The inclusion constraint solving phase iteratively propagates the inclusion +// constraints until a fixed point is reached. This is an O(N^3) algorithm. +// +// Function constraints are handled as if they were structs with X fields. +// Thus, an access to argument X of function Y is an access to node index +// getNode(Y) + X. This representation allows handling of indirect calls +// without any issues. To wit, an indirect call Y(a,b) is equivalent to +// *(Y + 1) = a, *(Y + 2) = b. +// The return node for a function is always located at getNode(F) + +// CallReturnPos. The arguments start at getNode(F) + CallArgPos. +// +// Future Improvements: +// Use of BDD's. 
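// Illustrative sketch (not part of the patch above): a minimal, standalone
// model of the inclusion-constraint solving phase described in the header
// comment.  Variables and objects are plain indices, points-to sets are
// std::set<unsigned>, and Copy/Load/Store/AddressOf constraints are
// re-applied until a fixed point is reached -- the naive O(N^3) formulation.
// Only the C++ standard library is assumed and every name here is
// hypothetical; the real pass below uses SparseBitVector, offline variable
// substitution, cycle elimination and a prioritized work list instead of this
// brute-force loop.
#include <cstdio>
#include <set>
#include <vector>

namespace andersens_sketch {

struct Constraint {
  enum Kind { AddressOf, Copy, Load, Store } K;
  unsigned Dest, Src;  // AddressOf: Dest >= {Src}, Copy: Dest >= Src,
                       // Load: Dest >= *Src,       Store: *Dest >= Src.
};

typedef std::set<unsigned> PtsSet;

// Union Src into Dest, reporting whether Dest actually grew.
bool unionInto(PtsSet &Dest, const PtsSet &Src) {
  bool Grew = false;
  for (PtsSet::const_iterator I = Src.begin(), E = Src.end(); I != E; ++I)
    Grew |= Dest.insert(*I).second;
  return Grew;
}

// Re-apply every constraint until no points-to set changes.
std::vector<PtsSet> solve(unsigned NumVars,
                          const std::vector<Constraint> &Cs) {
  std::vector<PtsSet> P(NumVars);
  bool Changed = true;
  while (Changed) {
    Changed = false;
    for (std::size_t i = 0; i != Cs.size(); ++i) {
      const Constraint &C = Cs[i];
      switch (C.K) {
      case Constraint::AddressOf:            // Dest may point at object Src.
        Changed |= P[C.Dest].insert(C.Src).second;
        break;
      case Constraint::Copy:                 // Dest inherits Src's pointees.
        Changed |= unionInto(P[C.Dest], P[C.Src]);
        break;
      case Constraint::Load: {               // Dest inherits (*Src)'s pointees.
        PtsSet Pointees = P[C.Src];          // Copy; P may grow as we iterate.
        for (PtsSet::iterator I = Pointees.begin(), E = Pointees.end();
             I != E; ++I)
          Changed |= unionInto(P[C.Dest], P[*I]);
        break;
      }
      case Constraint::Store: {              // Every pointee of Dest inherits Src.
        PtsSet Pointees = P[C.Dest];
        for (PtsSet::iterator I = Pointees.begin(), E = Pointees.end();
             I != E; ++I)
          Changed |= unionInto(P[*I], P[C.Src]);
        break;
      }
      }
    }
  }
  return P;
}

// Example program:  p = &a;  q = p;  t = &b;  *q = t;
// After solving, object 'a' may point to object 'b' through the store.
void example() {
  enum { P0, Q, T, A, B, NumVars };
  const Constraint Program[] = {
    { Constraint::AddressOf, P0, A },   // p = &a
    { Constraint::Copy,      Q,  P0 },  // q = p
    { Constraint::AddressOf, T,  B },   // t = &b
    { Constraint::Store,     Q,  T },   // *q = t
  };
  std::vector<Constraint> Cs(Program, Program + 4);
  std::vector<PtsSet> Pts = solve(NumVars, Cs);
  std::printf("a may point to b: %s\n", Pts[A].count(B) ? "yes" : "no");
}

} // namespace andersens_sketch
// End of illustrative sketch.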
+//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "anders-aa" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Instructions.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/Support/InstVisitor.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SparseBitVector.h" +#include "llvm/ADT/DenseSet.h" +#include <algorithm> +#include <set> +#include <list> +#include <map> +#include <stack> +#include <vector> +#include <queue> + +// Determining the actual set of nodes the universal set can consist of is very +// expensive because it means propagating around very large sets. We rely on +// other analysis being able to determine which nodes can never be pointed to in +// order to disambiguate further than "points-to anything". +#define FULL_UNIVERSAL 0 + +using namespace llvm; +STATISTIC(NumIters , "Number of iterations to reach convergence"); +STATISTIC(NumConstraints, "Number of constraints"); +STATISTIC(NumNodes , "Number of nodes"); +STATISTIC(NumUnified , "Number of variables unified"); +STATISTIC(NumErased , "Number of redundant constraints erased"); + +static const unsigned SelfRep = (unsigned)-1; +static const unsigned Unvisited = (unsigned)-1; +// Position of the function return node relative to the function node. +static const unsigned CallReturnPos = 1; +// Position of the function call node relative to the function node. +static const unsigned CallFirstArgPos = 2; + +namespace { + struct BitmapKeyInfo { + static inline SparseBitVector<> *getEmptyKey() { + return reinterpret_cast<SparseBitVector<> *>(-1); + } + static inline SparseBitVector<> *getTombstoneKey() { + return reinterpret_cast<SparseBitVector<> *>(-2); + } + static unsigned getHashValue(const SparseBitVector<> *bitmap) { + return bitmap->getHashValue(); + } + static bool isEqual(const SparseBitVector<> *LHS, + const SparseBitVector<> *RHS) { + if (LHS == RHS) + return true; + else if (LHS == getEmptyKey() || RHS == getEmptyKey() + || LHS == getTombstoneKey() || RHS == getTombstoneKey()) + return false; + + return *LHS == *RHS; + } + + static bool isPod() { return true; } + }; + + class VISIBILITY_HIDDEN Andersens : public ModulePass, public AliasAnalysis, + private InstVisitor<Andersens> { + struct Node; + + /// Constraint - Objects of this structure are used to represent the various + /// constraints identified by the algorithm. The constraints are 'copy', + /// for statements like "A = B", 'load' for statements like "A = *B", + /// 'store' for statements like "*A = B", and AddressOf for statements like + /// A = alloca; The Offset is applied as *(A + K) = B for stores, + /// A = *(B + K) for loads, and A = B + K for copies. 
It is + /// illegal on addressof constraints (because it is statically + /// resolvable to A = &C where C = B + K) + + struct Constraint { + enum ConstraintType { Copy, Load, Store, AddressOf } Type; + unsigned Dest; + unsigned Src; + unsigned Offset; + + Constraint(ConstraintType Ty, unsigned D, unsigned S, unsigned O = 0) + : Type(Ty), Dest(D), Src(S), Offset(O) { + assert((Offset == 0 || Ty != AddressOf) && + "Offset is illegal on addressof constraints"); + } + + bool operator==(const Constraint &RHS) const { + return RHS.Type == Type + && RHS.Dest == Dest + && RHS.Src == Src + && RHS.Offset == Offset; + } + + bool operator!=(const Constraint &RHS) const { + return !(*this == RHS); + } + + bool operator<(const Constraint &RHS) const { + if (RHS.Type != Type) + return RHS.Type < Type; + else if (RHS.Dest != Dest) + return RHS.Dest < Dest; + else if (RHS.Src != Src) + return RHS.Src < Src; + return RHS.Offset < Offset; + } + }; + + // Information DenseSet requires implemented in order to be able to do + // it's thing + struct PairKeyInfo { + static inline std::pair<unsigned, unsigned> getEmptyKey() { + return std::make_pair(~0U, ~0U); + } + static inline std::pair<unsigned, unsigned> getTombstoneKey() { + return std::make_pair(~0U - 1, ~0U - 1); + } + static unsigned getHashValue(const std::pair<unsigned, unsigned> &P) { + return P.first ^ P.second; + } + static unsigned isEqual(const std::pair<unsigned, unsigned> &LHS, + const std::pair<unsigned, unsigned> &RHS) { + return LHS == RHS; + } + }; + + struct ConstraintKeyInfo { + static inline Constraint getEmptyKey() { + return Constraint(Constraint::Copy, ~0U, ~0U, ~0U); + } + static inline Constraint getTombstoneKey() { + return Constraint(Constraint::Copy, ~0U - 1, ~0U - 1, ~0U - 1); + } + static unsigned getHashValue(const Constraint &C) { + return C.Src ^ C.Dest ^ C.Type ^ C.Offset; + } + static bool isEqual(const Constraint &LHS, + const Constraint &RHS) { + return LHS.Type == RHS.Type && LHS.Dest == RHS.Dest + && LHS.Src == RHS.Src && LHS.Offset == RHS.Offset; + } + }; + + // Node class - This class is used to represent a node in the constraint + // graph. Due to various optimizations, it is not always the case that + // there is a mapping from a Node to a Value. In particular, we add + // artificial Node's that represent the set of pointed-to variables shared + // for each location equivalent Node. + struct Node { + private: + static unsigned Counter; + + public: + Value *Val; + SparseBitVector<> *Edges; + SparseBitVector<> *PointsTo; + SparseBitVector<> *OldPointsTo; + std::list<Constraint> Constraints; + + // Pointer and location equivalence labels + unsigned PointerEquivLabel; + unsigned LocationEquivLabel; + // Predecessor edges, both real and implicit + SparseBitVector<> *PredEdges; + SparseBitVector<> *ImplicitPredEdges; + // Set of nodes that point to us, only use for location equivalence. + SparseBitVector<> *PointedToBy; + // Number of incoming edges, used during variable substitution to early + // free the points-to sets + unsigned NumInEdges; + // True if our points-to set is in the Set2PEClass map + bool StoredInHash; + // True if our node has no indirect constraints (complex or otherwise) + bool Direct; + // True if the node is address taken, *or* it is part of a group of nodes + // that must be kept together. This is set to true for functions and + // their arg nodes, which must be kept at the same position relative to + // their base function node. 
+ bool AddressTaken; + + // Nodes in cycles (or in equivalence classes) are united together using a + // standard union-find representation with path compression. NodeRep + // gives the index into GraphNodes for the representative Node. + unsigned NodeRep; + + // Modification timestamp. Assigned from Counter. + // Used for work list prioritization. + unsigned Timestamp; + + explicit Node(bool direct = true) : + Val(0), Edges(0), PointsTo(0), OldPointsTo(0), + PointerEquivLabel(0), LocationEquivLabel(0), PredEdges(0), + ImplicitPredEdges(0), PointedToBy(0), NumInEdges(0), + StoredInHash(false), Direct(direct), AddressTaken(false), + NodeRep(SelfRep), Timestamp(0) { } + + Node *setValue(Value *V) { + assert(Val == 0 && "Value already set for this node!"); + Val = V; + return this; + } + + /// getValue - Return the LLVM value corresponding to this node. + /// + Value *getValue() const { return Val; } + + /// addPointerTo - Add a pointer to the list of pointees of this node, + /// returning true if this caused a new pointer to be added, or false if + /// we already knew about the points-to relation. + bool addPointerTo(unsigned Node) { + return PointsTo->test_and_set(Node); + } + + /// intersects - Return true if the points-to set of this node intersects + /// with the points-to set of the specified node. + bool intersects(Node *N) const; + + /// intersectsIgnoring - Return true if the points-to set of this node + /// intersects with the points-to set of the specified node on any nodes + /// except for the specified node to ignore. + bool intersectsIgnoring(Node *N, unsigned) const; + + // Timestamp a node (used for work list prioritization) + void Stamp() { + Timestamp = Counter++; + } + + bool isRep() const { + return( (int) NodeRep < 0 ); + } + }; + + struct WorkListElement { + Node* node; + unsigned Timestamp; + WorkListElement(Node* n, unsigned t) : node(n), Timestamp(t) {} + + // Note that we reverse the sense of the comparison because we + // actually want to give low timestamps the priority over high, + // whereas priority is typically interpreted as a greater value is + // given high priority. + bool operator<(const WorkListElement& that) const { + return( this->Timestamp > that.Timestamp ); + } + }; + + // Priority-queue based work list specialized for Nodes. + class WorkList { + std::priority_queue<WorkListElement> Q; + + public: + void insert(Node* n) { + Q.push( WorkListElement(n, n->Timestamp) ); + } + + // We automatically discard non-representative nodes and nodes + // that were in the work list twice (we keep a copy of the + // timestamp in the work list so we can detect this situation by + // comparing against the node's current timestamp). + Node* pop() { + while( !Q.empty() ) { + WorkListElement x = Q.top(); Q.pop(); + Node* INode = x.node; + + if( INode->isRep() && + INode->Timestamp == x.Timestamp ) { + return(x.node); + } + } + return(0); + } + + bool empty() { + return Q.empty(); + } + }; + + /// GraphNodes - This vector is populated as part of the object + /// identification stage of the analysis, which populates this vector with a + /// node for each memory object and fills in the ValueNodes map. + std::vector<Node> GraphNodes; + + /// ValueNodes - This map indicates the Node that a particular Value* is + /// represented by. This contains entries for all pointers. + DenseMap<Value*, unsigned> ValueNodes; + + /// ObjectNodes - This map contains entries for each memory object in the + /// program: globals, alloca's and mallocs. 
+ DenseMap<Value*, unsigned> ObjectNodes; + + /// ReturnNodes - This map contains an entry for each function in the + /// program that returns a value. + DenseMap<Function*, unsigned> ReturnNodes; + + /// VarargNodes - This map contains the entry used to represent all pointers + /// passed through the varargs portion of a function call for a particular + /// function. An entry is not present in this map for functions that do not + /// take variable arguments. + DenseMap<Function*, unsigned> VarargNodes; + + + /// Constraints - This vector contains a list of all of the constraints + /// identified by the program. + std::vector<Constraint> Constraints; + + // Map from graph node to maximum K value that is allowed (for functions, + // this is equivalent to the number of arguments + CallFirstArgPos) + std::map<unsigned, unsigned> MaxK; + + /// This enum defines the GraphNodes indices that correspond to important + /// fixed sets. + enum { + UniversalSet = 0, + NullPtr = 1, + NullObject = 2, + NumberSpecialNodes + }; + // Stack for Tarjan's + std::stack<unsigned> SCCStack; + // Map from Graph Node to DFS number + std::vector<unsigned> Node2DFS; + // Map from Graph Node to Deleted from graph. + std::vector<bool> Node2Deleted; + // Same as Node Maps, but implemented as std::map because it is faster to + // clear + std::map<unsigned, unsigned> Tarjan2DFS; + std::map<unsigned, bool> Tarjan2Deleted; + // Current DFS number + unsigned DFSNumber; + + // Work lists. + WorkList w1, w2; + WorkList *CurrWL, *NextWL; // "current" and "next" work lists + + // Offline variable substitution related things + + // Temporary rep storage, used because we can't collapse SCC's in the + // predecessor graph by uniting the variables permanently, we can only do so + // for the successor graph. + std::vector<unsigned> VSSCCRep; + // Mapping from node to whether we have visited it during SCC finding yet. + std::vector<bool> Node2Visited; + // During variable substitution, we create unknowns to represent the unknown + // value that is a dereference of a variable. These nodes are known as + // "ref" nodes (since they represent the value of dereferences). + unsigned FirstRefNode; + // During HVN, we create represent address taken nodes as if they were + // unknown (since HVN, unlike HU, does not evaluate unions). + unsigned FirstAdrNode; + // Current pointer equivalence class number + unsigned PEClass; + // Mapping from points-to sets to equivalence classes + typedef DenseMap<SparseBitVector<> *, unsigned, BitmapKeyInfo> BitVectorMap; + BitVectorMap Set2PEClass; + // Mapping from pointer equivalences to the representative node. -1 if we + // have no representative node for this pointer equivalence class yet. + std::vector<int> PEClass2Node; + // Mapping from pointer equivalences to representative node. This includes + // pointer equivalent but not location equivalent variables. -1 if we have + // no representative node for this pointer equivalence class yet. + std::vector<int> PENLEClass2Node; + // Union/Find for HCD + std::vector<unsigned> HCDSCCRep; + // HCD's offline-detected cycles; "Statically DeTected" + // -1 if not part of such a cycle, otherwise a representative node. 
+ std::vector<int> SDT; + // Whether to use SDT (UniteNodes can use it during solving, but not before) + bool SDTActive; + + public: + static char ID; + Andersens() : ModulePass(&ID) {} + + bool runOnModule(Module &M) { + InitializeAliasAnalysis(this); + IdentifyObjects(M); + CollectConstraints(M); +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa-constraints" + DEBUG(PrintConstraints()); +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa" + SolveConstraints(); + DEBUG(PrintPointsToGraph()); + + // Free the constraints list, as we don't need it to respond to alias + // requests. + std::vector<Constraint>().swap(Constraints); + //These are needed for Print() (-analyze in opt) + //ObjectNodes.clear(); + //ReturnNodes.clear(); + //VarargNodes.clear(); + return false; + } + + void releaseMemory() { + // FIXME: Until we have transitively required passes working correctly, + // this cannot be enabled! Otherwise, using -count-aa with the pass + // causes memory to be freed too early. :( +#if 0 + // The memory objects and ValueNodes data structures at the only ones that + // are still live after construction. + std::vector<Node>().swap(GraphNodes); + ValueNodes.clear(); +#endif + } + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AliasAnalysis::getAnalysisUsage(AU); + AU.setPreservesAll(); // Does not transform code + } + + //------------------------------------------------ + // Implement the AliasAnalysis API + // + AliasResult alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size); + virtual ModRefResult getModRefInfo(CallSite CS, Value *P, unsigned Size); + virtual ModRefResult getModRefInfo(CallSite CS1, CallSite CS2); + void getMustAliases(Value *P, std::vector<Value*> &RetVals); + bool pointsToConstantMemory(const Value *P); + + virtual void deleteValue(Value *V) { + ValueNodes.erase(V); + getAnalysis<AliasAnalysis>().deleteValue(V); + } + + virtual void copyValue(Value *From, Value *To) { + ValueNodes[To] = ValueNodes[From]; + getAnalysis<AliasAnalysis>().copyValue(From, To); + } + + private: + /// getNode - Return the node corresponding to the specified pointer scalar. + /// + unsigned getNode(Value *V) { + if (Constant *C = dyn_cast<Constant>(V)) + if (!isa<GlobalValue>(C)) + return getNodeForConstantPointer(C); + + DenseMap<Value*, unsigned>::iterator I = ValueNodes.find(V); + if (I == ValueNodes.end()) { +#ifndef NDEBUG + V->dump(); +#endif + assert(0 && "Value does not have a node in the points-to graph!"); + } + return I->second; + } + + /// getObject - Return the node corresponding to the memory object for the + /// specified global or allocation instruction. + unsigned getObject(Value *V) const { + DenseMap<Value*, unsigned>::iterator I = ObjectNodes.find(V); + assert(I != ObjectNodes.end() && + "Value does not have an object in the points-to graph!"); + return I->second; + } + + /// getReturnNode - Return the node representing the return value for the + /// specified function. + unsigned getReturnNode(Function *F) const { + DenseMap<Function*, unsigned>::iterator I = ReturnNodes.find(F); + assert(I != ReturnNodes.end() && "Function does not return a value!"); + return I->second; + } + + /// getVarargNode - Return the node representing the variable arguments + /// formal for the specified function. 
+ unsigned getVarargNode(Function *F) const { + DenseMap<Function*, unsigned>::iterator I = VarargNodes.find(F); + assert(I != VarargNodes.end() && "Function does not take var args!"); + return I->second; + } + + /// getNodeValue - Get the node for the specified LLVM value and set the + /// value for it to be the specified value. + unsigned getNodeValue(Value &V) { + unsigned Index = getNode(&V); + GraphNodes[Index].setValue(&V); + return Index; + } + + unsigned UniteNodes(unsigned First, unsigned Second, + bool UnionByRank = true); + unsigned FindNode(unsigned Node); + unsigned FindNode(unsigned Node) const; + + void IdentifyObjects(Module &M); + void CollectConstraints(Module &M); + bool AnalyzeUsesOfFunction(Value *); + void CreateConstraintGraph(); + void OptimizeConstraints(); + unsigned FindEquivalentNode(unsigned, unsigned); + void ClumpAddressTaken(); + void RewriteConstraints(); + void HU(); + void HVN(); + void HCD(); + void Search(unsigned Node); + void UnitePointerEquivalences(); + void SolveConstraints(); + bool QueryNode(unsigned Node); + void Condense(unsigned Node); + void HUValNum(unsigned Node); + void HVNValNum(unsigned Node); + unsigned getNodeForConstantPointer(Constant *C); + unsigned getNodeForConstantPointerTarget(Constant *C); + void AddGlobalInitializerConstraints(unsigned, Constant *C); + + void AddConstraintsForNonInternalLinkage(Function *F); + void AddConstraintsForCall(CallSite CS, Function *F); + bool AddConstraintsForExternalCall(CallSite CS, Function *F); + + + void PrintNode(const Node *N) const; + void PrintConstraints() const ; + void PrintConstraint(const Constraint &) const; + void PrintLabels() const; + void PrintPointsToGraph() const; + + //===------------------------------------------------------------------===// + // Instruction visitation methods for adding constraints + // + friend class InstVisitor<Andersens>; + void visitReturnInst(ReturnInst &RI); + void visitInvokeInst(InvokeInst &II) { visitCallSite(CallSite(&II)); } + void visitCallInst(CallInst &CI) { visitCallSite(CallSite(&CI)); } + void visitCallSite(CallSite CS); + void visitAllocationInst(AllocationInst &AI); + void visitLoadInst(LoadInst &LI); + void visitStoreInst(StoreInst &SI); + void visitGetElementPtrInst(GetElementPtrInst &GEP); + void visitPHINode(PHINode &PN); + void visitCastInst(CastInst &CI); + void visitICmpInst(ICmpInst &ICI) {} // NOOP! + void visitFCmpInst(FCmpInst &ICI) {} // NOOP! + void visitSelectInst(SelectInst &SI); + void visitVAArg(VAArgInst &I); + void visitInstruction(Instruction &I); + + //===------------------------------------------------------------------===// + // Implement Analyize interface + // + void print(std::ostream &O, const Module* M) const { + PrintPointsToGraph(); + } + }; +} + +char Andersens::ID = 0; +static RegisterPass<Andersens> +X("anders-aa", "Andersen's Interprocedural Alias Analysis", false, true); +static RegisterAnalysisGroup<AliasAnalysis> Y(X); + +// Initialize Timestamp Counter (static). 
+unsigned Andersens::Node::Counter = 0; + +ModulePass *llvm::createAndersensPass() { return new Andersens(); } + +//===----------------------------------------------------------------------===// +// AliasAnalysis Interface Implementation +//===----------------------------------------------------------------------===// + +AliasAnalysis::AliasResult Andersens::alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size) { + Node *N1 = &GraphNodes[FindNode(getNode(const_cast<Value*>(V1)))]; + Node *N2 = &GraphNodes[FindNode(getNode(const_cast<Value*>(V2)))]; + + // Check to see if the two pointers are known to not alias. They don't alias + // if their points-to sets do not intersect. + if (!N1->intersectsIgnoring(N2, NullObject)) + return NoAlias; + + return AliasAnalysis::alias(V1, V1Size, V2, V2Size); +} + +AliasAnalysis::ModRefResult +Andersens::getModRefInfo(CallSite CS, Value *P, unsigned Size) { + // The only thing useful that we can contribute for mod/ref information is + // when calling external function calls: if we know that memory never escapes + // from the program, it cannot be modified by an external call. + // + // NOTE: This is not really safe, at least not when the entire program is not + // available. The deal is that the external function could call back into the + // program and modify stuff. We ignore this technical niggle for now. This + // is, after all, a "research quality" implementation of Andersen's analysis. + if (Function *F = CS.getCalledFunction()) + if (F->isDeclaration()) { + Node *N1 = &GraphNodes[FindNode(getNode(P))]; + + if (N1->PointsTo->empty()) + return NoModRef; +#if FULL_UNIVERSAL + if (!UniversalSet->PointsTo->test(FindNode(getNode(P)))) + return NoModRef; // Universal set does not contain P +#else + if (!N1->PointsTo->test(UniversalSet)) + return NoModRef; // P doesn't point to the universal set. +#endif + } + + return AliasAnalysis::getModRefInfo(CS, P, Size); +} + +AliasAnalysis::ModRefResult +Andersens::getModRefInfo(CallSite CS1, CallSite CS2) { + return AliasAnalysis::getModRefInfo(CS1,CS2); +} + +/// getMustAlias - We can provide must alias information if we know that a +/// pointer can only point to a specific function or the null pointer. +/// Unfortunately we cannot determine must-alias information for global +/// variables or any other memory memory objects because we do not track whether +/// a pointer points to the beginning of an object or a field of it. +void Andersens::getMustAliases(Value *P, std::vector<Value*> &RetVals) { + Node *N = &GraphNodes[FindNode(getNode(P))]; + if (N->PointsTo->count() == 1) { + Node *Pointee = &GraphNodes[N->PointsTo->find_first()]; + // If a function is the only object in the points-to set, then it must be + // the destination. Note that we can't handle global variables here, + // because we don't know if the pointer is actually pointing to a field of + // the global or to the beginning of it. + if (Value *V = Pointee->getValue()) { + if (Function *F = dyn_cast<Function>(V)) + RetVals.push_back(F); + } else { + // If the object in the points-to set is the null object, then the null + // pointer is a must alias. + if (Pointee == &GraphNodes[NullObject]) + RetVals.push_back(Constant::getNullValue(P->getType())); + } + } + AliasAnalysis::getMustAliases(P, RetVals); +} + +/// pointsToConstantMemory - If we can determine that this pointer only points +/// to constant memory, return true. 
In practice, this means that if the +/// pointer can only point to constant globals, functions, or the null pointer, +/// return true. +/// +bool Andersens::pointsToConstantMemory(const Value *P) { + Node *N = &GraphNodes[FindNode(getNode(const_cast<Value*>(P)))]; + unsigned i; + + for (SparseBitVector<>::iterator bi = N->PointsTo->begin(); + bi != N->PointsTo->end(); + ++bi) { + i = *bi; + Node *Pointee = &GraphNodes[i]; + if (Value *V = Pointee->getValue()) { + if (!isa<GlobalValue>(V) || (isa<GlobalVariable>(V) && + !cast<GlobalVariable>(V)->isConstant())) + return AliasAnalysis::pointsToConstantMemory(P); + } else { + if (i != NullObject) + return AliasAnalysis::pointsToConstantMemory(P); + } + } + + return true; +} + +//===----------------------------------------------------------------------===// +// Object Identification Phase +//===----------------------------------------------------------------------===// + +/// IdentifyObjects - This stage scans the program, adding an entry to the +/// GraphNodes list for each memory object in the program (global stack or +/// heap), and populates the ValueNodes and ObjectNodes maps for these objects. +/// +void Andersens::IdentifyObjects(Module &M) { + unsigned NumObjects = 0; + + // Object #0 is always the universal set: the object that we don't know + // anything about. + assert(NumObjects == UniversalSet && "Something changed!"); + ++NumObjects; + + // Object #1 always represents the null pointer. + assert(NumObjects == NullPtr && "Something changed!"); + ++NumObjects; + + // Object #2 always represents the null object (the object pointed to by null) + assert(NumObjects == NullObject && "Something changed!"); + ++NumObjects; + + // Add all the globals first. + for (Module::global_iterator I = M.global_begin(), E = M.global_end(); + I != E; ++I) { + ObjectNodes[I] = NumObjects++; + ValueNodes[I] = NumObjects++; + } + + // Add nodes for all of the functions and the instructions inside of them. + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) { + // The function itself is a memory object. + unsigned First = NumObjects; + ValueNodes[F] = NumObjects++; + if (isa<PointerType>(F->getFunctionType()->getReturnType())) + ReturnNodes[F] = NumObjects++; + if (F->getFunctionType()->isVarArg()) + VarargNodes[F] = NumObjects++; + + + // Add nodes for all of the incoming pointer arguments. + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); + I != E; ++I) + { + if (isa<PointerType>(I->getType())) + ValueNodes[I] = NumObjects++; + } + MaxK[First] = NumObjects - First; + + // Scan the function body, creating a memory object for each heap/stack + // allocation in the body of the function and a node to represent all + // pointer values defined by instructions and used as operands. + for (inst_iterator II = inst_begin(F), E = inst_end(F); II != E; ++II) { + // If this is an heap or stack allocation, create a node for the memory + // object. + if (isa<PointerType>(II->getType())) { + ValueNodes[&*II] = NumObjects++; + if (AllocationInst *AI = dyn_cast<AllocationInst>(&*II)) + ObjectNodes[AI] = NumObjects++; + } + + // Calls to inline asm need to be added as well because the callee isn't + // referenced anywhere else. + if (CallInst *CI = dyn_cast<CallInst>(&*II)) { + Value *Callee = CI->getCalledValue(); + if (isa<InlineAsm>(Callee)) + ValueNodes[Callee] = NumObjects++; + } + } + } + + // Now that we know how many objects to create, make them all now! 
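// A small, hypothetical numbering to make the layout concrete (the exact
// indices depend on iteration order; this is only an illustration).  For a
// module containing
//
//   @g = global i8* null
//   define i8* @f(i8* %a) { ... }
//
// the code above assigns 0, 1 and 2 to UniversalSet, NullPtr and NullObject;
// 3 = ObjectNodes[@g] and 4 = ValueNodes[@g]; then 5 = ValueNodes[@f],
// 6 = ReturnNodes[@f] (pointer return), 7 = ValueNodes[%a], and MaxK[5] = 3,
// the size of @f's block of nodes.  Pointer-typed instructions inside @f and
// their allocations are numbered after that.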
+ GraphNodes.resize(NumObjects); + NumNodes += NumObjects; +} + +//===----------------------------------------------------------------------===// +// Constraint Identification Phase +//===----------------------------------------------------------------------===// + +/// getNodeForConstantPointer - Return the node corresponding to the constant +/// pointer itself. +unsigned Andersens::getNodeForConstantPointer(Constant *C) { + assert(isa<PointerType>(C->getType()) && "Not a constant pointer!"); + + if (isa<ConstantPointerNull>(C) || isa<UndefValue>(C)) + return NullPtr; + else if (GlobalValue *GV = dyn_cast<GlobalValue>(C)) + return getNode(GV); + else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { + switch (CE->getOpcode()) { + case Instruction::GetElementPtr: + return getNodeForConstantPointer(CE->getOperand(0)); + case Instruction::IntToPtr: + return UniversalSet; + case Instruction::BitCast: + return getNodeForConstantPointer(CE->getOperand(0)); + default: + cerr << "Constant Expr not yet handled: " << *CE << "\n"; + assert(0); + } + } else { + assert(0 && "Unknown constant pointer!"); + } + return 0; +} + +/// getNodeForConstantPointerTarget - Return the node POINTED TO by the +/// specified constant pointer. +unsigned Andersens::getNodeForConstantPointerTarget(Constant *C) { + assert(isa<PointerType>(C->getType()) && "Not a constant pointer!"); + + if (isa<ConstantPointerNull>(C)) + return NullObject; + else if (GlobalValue *GV = dyn_cast<GlobalValue>(C)) + return getObject(GV); + else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { + switch (CE->getOpcode()) { + case Instruction::GetElementPtr: + return getNodeForConstantPointerTarget(CE->getOperand(0)); + case Instruction::IntToPtr: + return UniversalSet; + case Instruction::BitCast: + return getNodeForConstantPointerTarget(CE->getOperand(0)); + default: + cerr << "Constant Expr not yet handled: " << *CE << "\n"; + assert(0); + } + } else { + assert(0 && "Unknown constant pointer!"); + } + return 0; +} + +/// AddGlobalInitializerConstraints - Add inclusion constraints for the memory +/// object N, which contains values indicated by C. +void Andersens::AddGlobalInitializerConstraints(unsigned NodeIndex, + Constant *C) { + if (C->getType()->isSingleValueType()) { + if (isa<PointerType>(C->getType())) + Constraints.push_back(Constraint(Constraint::Copy, NodeIndex, + getNodeForConstantPointer(C))); + } else if (C->isNullValue()) { + Constraints.push_back(Constraint(Constraint::Copy, NodeIndex, + NullObject)); + return; + } else if (!isa<UndefValue>(C)) { + // If this is an array or struct, include constraints for each element. + assert(isa<ConstantArray>(C) || isa<ConstantStruct>(C)); + for (unsigned i = 0, e = C->getNumOperands(); i != e; ++i) + AddGlobalInitializerConstraints(NodeIndex, + cast<Constant>(C->getOperand(i))); + } +} + +/// AddConstraintsForNonInternalLinkage - If this function does not have +/// internal linkage, realize that we can't trust anything passed into or +/// returned by this function. +void Andersens::AddConstraintsForNonInternalLinkage(Function *F) { + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E; ++I) + if (isa<PointerType>(I->getType())) + // If this is an argument of an externally accessible function, the + // incoming pointer might point to anything. + Constraints.push_back(Constraint(Constraint::Copy, getNode(I), + UniversalSet)); +} + +/// AddConstraintsForCall - If this is a call to a "known" function, add the +/// constraints and return true. 
If this is a call to an unknown function, +/// return false. +bool Andersens::AddConstraintsForExternalCall(CallSite CS, Function *F) { + assert(F->isDeclaration() && "Not an external function!"); + + // These functions don't induce any points-to constraints. + if (F->getName() == "atoi" || F->getName() == "atof" || + F->getName() == "atol" || F->getName() == "atoll" || + F->getName() == "remove" || F->getName() == "unlink" || + F->getName() == "rename" || F->getName() == "memcmp" || + F->getName() == "llvm.memset" || + F->getName() == "strcmp" || F->getName() == "strncmp" || + F->getName() == "execl" || F->getName() == "execlp" || + F->getName() == "execle" || F->getName() == "execv" || + F->getName() == "execvp" || F->getName() == "chmod" || + F->getName() == "puts" || F->getName() == "write" || + F->getName() == "open" || F->getName() == "create" || + F->getName() == "truncate" || F->getName() == "chdir" || + F->getName() == "mkdir" || F->getName() == "rmdir" || + F->getName() == "read" || F->getName() == "pipe" || + F->getName() == "wait" || F->getName() == "time" || + F->getName() == "stat" || F->getName() == "fstat" || + F->getName() == "lstat" || F->getName() == "strtod" || + F->getName() == "strtof" || F->getName() == "strtold" || + F->getName() == "fopen" || F->getName() == "fdopen" || + F->getName() == "freopen" || + F->getName() == "fflush" || F->getName() == "feof" || + F->getName() == "fileno" || F->getName() == "clearerr" || + F->getName() == "rewind" || F->getName() == "ftell" || + F->getName() == "ferror" || F->getName() == "fgetc" || + F->getName() == "fgetc" || F->getName() == "_IO_getc" || + F->getName() == "fwrite" || F->getName() == "fread" || + F->getName() == "fgets" || F->getName() == "ungetc" || + F->getName() == "fputc" || + F->getName() == "fputs" || F->getName() == "putc" || + F->getName() == "ftell" || F->getName() == "rewind" || + F->getName() == "_IO_putc" || F->getName() == "fseek" || + F->getName() == "fgetpos" || F->getName() == "fsetpos" || + F->getName() == "printf" || F->getName() == "fprintf" || + F->getName() == "sprintf" || F->getName() == "vprintf" || + F->getName() == "vfprintf" || F->getName() == "vsprintf" || + F->getName() == "scanf" || F->getName() == "fscanf" || + F->getName() == "sscanf" || F->getName() == "__assert_fail" || + F->getName() == "modf") + return true; + + + // These functions do induce points-to edges. + if (F->getName() == "llvm.memcpy" || + F->getName() == "llvm.memmove" || + F->getName() == "memmove") { + + const FunctionType *FTy = F->getFunctionType(); + if (FTy->getNumParams() > 1 && + isa<PointerType>(FTy->getParamType(0)) && + isa<PointerType>(FTy->getParamType(1))) { + + // *Dest = *Src, which requires an artificial graph node to represent the + // constraint. 
It is broken up into *Dest = temp, temp = *Src + unsigned FirstArg = getNode(CS.getArgument(0)); + unsigned SecondArg = getNode(CS.getArgument(1)); + unsigned TempArg = GraphNodes.size(); + GraphNodes.push_back(Node()); + Constraints.push_back(Constraint(Constraint::Store, + FirstArg, TempArg)); + Constraints.push_back(Constraint(Constraint::Load, + TempArg, SecondArg)); + // In addition, Dest = Src + Constraints.push_back(Constraint(Constraint::Copy, + FirstArg, SecondArg)); + return true; + } + } + + // Result = Arg0 + if (F->getName() == "realloc" || F->getName() == "strchr" || + F->getName() == "strrchr" || F->getName() == "strstr" || + F->getName() == "strtok") { + const FunctionType *FTy = F->getFunctionType(); + if (FTy->getNumParams() > 0 && + isa<PointerType>(FTy->getParamType(0))) { + Constraints.push_back(Constraint(Constraint::Copy, + getNode(CS.getInstruction()), + getNode(CS.getArgument(0)))); + return true; + } + } + + return false; +} + + + +/// AnalyzeUsesOfFunction - Look at all of the users of the specified function. +/// If this is used by anything complex (i.e., the address escapes), return +/// true. +bool Andersens::AnalyzeUsesOfFunction(Value *V) { + + if (!isa<PointerType>(V->getType())) return true; + + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E; ++UI) + if (dyn_cast<LoadInst>(*UI)) { + return false; + } else if (StoreInst *SI = dyn_cast<StoreInst>(*UI)) { + if (V == SI->getOperand(1)) { + return false; + } else if (SI->getOperand(1)) { + return true; // Storing the pointer + } + } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(*UI)) { + if (AnalyzeUsesOfFunction(GEP)) return true; + } else if (CallInst *CI = dyn_cast<CallInst>(*UI)) { + // Make sure that this is just the function being called, not that it is + // passing into the function. + for (unsigned i = 1, e = CI->getNumOperands(); i != e; ++i) + if (CI->getOperand(i) == V) return true; + } else if (InvokeInst *II = dyn_cast<InvokeInst>(*UI)) { + // Make sure that this is just the function being called, not that it is + // passing into the function. + for (unsigned i = 3, e = II->getNumOperands(); i != e; ++i) + if (II->getOperand(i) == V) return true; + } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(*UI)) { + if (CE->getOpcode() == Instruction::GetElementPtr || + CE->getOpcode() == Instruction::BitCast) { + if (AnalyzeUsesOfFunction(CE)) + return true; + } else { + return true; + } + } else if (ICmpInst *ICI = dyn_cast<ICmpInst>(*UI)) { + if (!isa<ConstantPointerNull>(ICI->getOperand(1))) + return true; // Allow comparison against null. + } else if (dyn_cast<FreeInst>(*UI)) { + return false; + } else { + return true; + } + return false; +} + +/// CollectConstraints - This stage scans the program, adding a constraint to +/// the Constraints list for each instruction in the program that induces a +/// constraint, and setting up the initial points-to graph. +/// +void Andersens::CollectConstraints(Module &M) { + // First, the universal set points to itself. + Constraints.push_back(Constraint(Constraint::AddressOf, UniversalSet, + UniversalSet)); + Constraints.push_back(Constraint(Constraint::Store, UniversalSet, + UniversalSet)); + + // Next, the null pointer points to the null object. + Constraints.push_back(Constraint(Constraint::AddressOf, NullPtr, NullObject)); + + // Next, add any constraints on global variables and their initializers. 
+ for (Module::global_iterator I = M.global_begin(), E = M.global_end(); + I != E; ++I) { + // Associate the address of the global object as pointing to the memory for + // the global: &G = <G memory> + unsigned ObjectIndex = getObject(I); + Node *Object = &GraphNodes[ObjectIndex]; + Object->setValue(I); + Constraints.push_back(Constraint(Constraint::AddressOf, getNodeValue(*I), + ObjectIndex)); + + if (I->hasInitializer()) { + AddGlobalInitializerConstraints(ObjectIndex, I->getInitializer()); + } else { + // If it doesn't have an initializer (i.e. it's defined in another + // translation unit), it points to the universal set. + Constraints.push_back(Constraint(Constraint::Copy, ObjectIndex, + UniversalSet)); + } + } + + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) { + // Set up the return value node. + if (isa<PointerType>(F->getFunctionType()->getReturnType())) + GraphNodes[getReturnNode(F)].setValue(F); + if (F->getFunctionType()->isVarArg()) + GraphNodes[getVarargNode(F)].setValue(F); + + // Set up incoming argument nodes. + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); + I != E; ++I) + if (isa<PointerType>(I->getType())) + getNodeValue(*I); + + // At some point we should just add constraints for the escaping functions + // at solve time, but this slows down solving. For now, we simply mark + // address taken functions as escaping and treat them as external. + if (!F->hasLocalLinkage() || AnalyzeUsesOfFunction(F)) + AddConstraintsForNonInternalLinkage(F); + + if (!F->isDeclaration()) { + // Scan the function body, creating a memory object for each heap/stack + // allocation in the body of the function and a node to represent all + // pointer values defined by instructions and used as operands. + visit(F); + } else { + // External functions that return pointers return the universal set. + if (isa<PointerType>(F->getFunctionType()->getReturnType())) + Constraints.push_back(Constraint(Constraint::Copy, + getReturnNode(F), + UniversalSet)); + + // Any pointers that are passed into the function have the universal set + // stored into them. + for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); + I != E; ++I) + if (isa<PointerType>(I->getType())) { + // Pointers passed into external functions could have anything stored + // through them. + Constraints.push_back(Constraint(Constraint::Store, getNode(I), + UniversalSet)); + // Memory objects passed into external function calls can have the + // universal set point to them. +#if FULL_UNIVERSAL + Constraints.push_back(Constraint(Constraint::Copy, + UniversalSet, + getNode(I))); +#else + Constraints.push_back(Constraint(Constraint::Copy, + getNode(I), + UniversalSet)); +#endif + } + + // If this is an external varargs function, it can also store pointers + // into any pointers passed through the varargs section. + if (F->getFunctionType()->isVarArg()) + Constraints.push_back(Constraint(Constraint::Store, getVarargNode(F), + UniversalSet)); + } + } + NumConstraints += Constraints.size(); +} + + +void Andersens::visitInstruction(Instruction &I) { +#ifdef NDEBUG + return; // This function is just a big assert. +#endif + if (isa<BinaryOperator>(I)) + return; + // Most instructions don't have any effect on pointer values. 
+ switch (I.getOpcode()) { + case Instruction::Br: + case Instruction::Switch: + case Instruction::Unwind: + case Instruction::Unreachable: + case Instruction::Free: + case Instruction::ICmp: + case Instruction::FCmp: + return; + default: + // Is this something we aren't handling yet? + cerr << "Unknown instruction: " << I; + abort(); + } +} + +void Andersens::visitAllocationInst(AllocationInst &AI) { + unsigned ObjectIndex = getObject(&AI); + GraphNodes[ObjectIndex].setValue(&AI); + Constraints.push_back(Constraint(Constraint::AddressOf, getNodeValue(AI), + ObjectIndex)); +} + +void Andersens::visitReturnInst(ReturnInst &RI) { + if (RI.getNumOperands() && isa<PointerType>(RI.getOperand(0)->getType())) + // return V --> <Copy/retval{F}/v> + Constraints.push_back(Constraint(Constraint::Copy, + getReturnNode(RI.getParent()->getParent()), + getNode(RI.getOperand(0)))); +} + +void Andersens::visitLoadInst(LoadInst &LI) { + if (isa<PointerType>(LI.getType())) + // P1 = load P2 --> <Load/P1/P2> + Constraints.push_back(Constraint(Constraint::Load, getNodeValue(LI), + getNode(LI.getOperand(0)))); +} + +void Andersens::visitStoreInst(StoreInst &SI) { + if (isa<PointerType>(SI.getOperand(0)->getType())) + // store P1, P2 --> <Store/P2/P1> + Constraints.push_back(Constraint(Constraint::Store, + getNode(SI.getOperand(1)), + getNode(SI.getOperand(0)))); +} + +void Andersens::visitGetElementPtrInst(GetElementPtrInst &GEP) { + // P1 = getelementptr P2, ... --> <Copy/P1/P2> + Constraints.push_back(Constraint(Constraint::Copy, getNodeValue(GEP), + getNode(GEP.getOperand(0)))); +} + +void Andersens::visitPHINode(PHINode &PN) { + if (isa<PointerType>(PN.getType())) { + unsigned PNN = getNodeValue(PN); + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) + // P1 = phi P2, P3 --> <Copy/P1/P2>, <Copy/P1/P3>, ... + Constraints.push_back(Constraint(Constraint::Copy, PNN, + getNode(PN.getIncomingValue(i)))); + } +} + +void Andersens::visitCastInst(CastInst &CI) { + Value *Op = CI.getOperand(0); + if (isa<PointerType>(CI.getType())) { + if (isa<PointerType>(Op->getType())) { + // P1 = cast P2 --> <Copy/P1/P2> + Constraints.push_back(Constraint(Constraint::Copy, getNodeValue(CI), + getNode(CI.getOperand(0)))); + } else { + // P1 = cast int --> <Copy/P1/Univ> +#if 0 + Constraints.push_back(Constraint(Constraint::Copy, getNodeValue(CI), + UniversalSet)); +#else + getNodeValue(CI); +#endif + } + } else if (isa<PointerType>(Op->getType())) { + // int = cast P1 --> <Copy/Univ/P1> +#if 0 + Constraints.push_back(Constraint(Constraint::Copy, + UniversalSet, + getNode(CI.getOperand(0)))); +#else + getNode(CI.getOperand(0)); +#endif + } +} + +void Andersens::visitSelectInst(SelectInst &SI) { + if (isa<PointerType>(SI.getType())) { + unsigned SIN = getNodeValue(SI); + // P1 = select C, P2, P3 ---> <Copy/P1/P2>, <Copy/P1/P3> + Constraints.push_back(Constraint(Constraint::Copy, SIN, + getNode(SI.getOperand(1)))); + Constraints.push_back(Constraint(Constraint::Copy, SIN, + getNode(SI.getOperand(2)))); + } +} + +void Andersens::visitVAArg(VAArgInst &I) { + assert(0 && "vaarg not handled yet!"); +} + +/// AddConstraintsForCall - Add constraints for a call with actual arguments +/// specified by CS to the function specified by F. Note that the types of +/// arguments might not match up in the case where this is an indirect call and +/// the function pointer has been casted. If this is the case, do something +/// reasonable. 
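// A schematic example of what this produces (hypothetical IR; nodes are
// written symbolically rather than as indices).  Recall from the visitors
// above that Copy means Dest's points-to set includes Src's, Load pulls in
// the sets of everything Src points to, Store adds Src's set to everything
// Dest points to, and AddressOf puts Src itself into Dest's set.  For a
// direct call
//
//   %r = call i8* @f(i8* %x)
//
// the code below adds Copy(%r, @f + CallReturnPos), i.e. the result copies
// from @f's return-value node, plus Copy(formal, %x) for each pointer-typed
// formal.  For an indirect call through %fp the callee is unknown, so the
// same slots are addressed by offset instead: Load(%r, %fp, CallReturnPos)
// for the result and Store(%fp, actual, CallFirstArgPos + i) for the i-th
// pointer actual; once %fp's points-to set is known, those offsets select
// the return and argument nodes of each possible callee.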
+void Andersens::AddConstraintsForCall(CallSite CS, Function *F) { + Value *CallValue = CS.getCalledValue(); + bool IsDeref = F == NULL; + + // If this is a call to an external function, try to handle it directly to get + // some taste of context sensitivity. + if (F && F->isDeclaration() && AddConstraintsForExternalCall(CS, F)) + return; + + if (isa<PointerType>(CS.getType())) { + unsigned CSN = getNode(CS.getInstruction()); + if (!F || isa<PointerType>(F->getFunctionType()->getReturnType())) { + if (IsDeref) + Constraints.push_back(Constraint(Constraint::Load, CSN, + getNode(CallValue), CallReturnPos)); + else + Constraints.push_back(Constraint(Constraint::Copy, CSN, + getNode(CallValue) + CallReturnPos)); + } else { + // If the function returns a non-pointer value, handle this just like we + // treat a nonpointer cast to pointer. + Constraints.push_back(Constraint(Constraint::Copy, CSN, + UniversalSet)); + } + } else if (F && isa<PointerType>(F->getFunctionType()->getReturnType())) { +#if FULL_UNIVERSAL + Constraints.push_back(Constraint(Constraint::Copy, + UniversalSet, + getNode(CallValue) + CallReturnPos)); +#else + Constraints.push_back(Constraint(Constraint::Copy, + getNode(CallValue) + CallReturnPos, + UniversalSet)); +#endif + + + } + + CallSite::arg_iterator ArgI = CS.arg_begin(), ArgE = CS.arg_end(); + bool external = !F || F->isDeclaration(); + if (F) { + // Direct Call + Function::arg_iterator AI = F->arg_begin(), AE = F->arg_end(); + for (; AI != AE && ArgI != ArgE; ++AI, ++ArgI) + { +#if !FULL_UNIVERSAL + if (external && isa<PointerType>((*ArgI)->getType())) + { + // Add constraint that ArgI can now point to anything due to + // escaping, as can everything it points to. The second portion of + // this should be taken care of by universal = *universal + Constraints.push_back(Constraint(Constraint::Copy, + getNode(*ArgI), + UniversalSet)); + } +#endif + if (isa<PointerType>(AI->getType())) { + if (isa<PointerType>((*ArgI)->getType())) { + // Copy the actual argument into the formal argument. + Constraints.push_back(Constraint(Constraint::Copy, getNode(AI), + getNode(*ArgI))); + } else { + Constraints.push_back(Constraint(Constraint::Copy, getNode(AI), + UniversalSet)); + } + } else if (isa<PointerType>((*ArgI)->getType())) { +#if FULL_UNIVERSAL + Constraints.push_back(Constraint(Constraint::Copy, + UniversalSet, + getNode(*ArgI))); +#else + Constraints.push_back(Constraint(Constraint::Copy, + getNode(*ArgI), + UniversalSet)); +#endif + } + } + } else { + //Indirect Call + unsigned ArgPos = CallFirstArgPos; + for (; ArgI != ArgE; ++ArgI) { + if (isa<PointerType>((*ArgI)->getType())) { + // Copy the actual argument into the formal argument. + Constraints.push_back(Constraint(Constraint::Store, + getNode(CallValue), + getNode(*ArgI), ArgPos++)); + } else { + Constraints.push_back(Constraint(Constraint::Store, + getNode (CallValue), + UniversalSet, ArgPos++)); + } + } + } + // Copy all pointers passed through the varargs section to the varargs node. + if (F && F->getFunctionType()->isVarArg()) + for (; ArgI != ArgE; ++ArgI) + if (isa<PointerType>((*ArgI)->getType())) + Constraints.push_back(Constraint(Constraint::Copy, getVarargNode(F), + getNode(*ArgI))); + // If more arguments are passed in than we track, just drop them on the floor. 
+} + +void Andersens::visitCallSite(CallSite CS) { + if (isa<PointerType>(CS.getType())) + getNodeValue(*CS.getInstruction()); + + if (Function *F = CS.getCalledFunction()) { + AddConstraintsForCall(CS, F); + } else { + AddConstraintsForCall(CS, NULL); + } +} + +//===----------------------------------------------------------------------===// +// Constraint Solving Phase +//===----------------------------------------------------------------------===// + +/// intersects - Return true if the points-to set of this node intersects +/// with the points-to set of the specified node. +bool Andersens::Node::intersects(Node *N) const { + return PointsTo->intersects(N->PointsTo); +} + +/// intersectsIgnoring - Return true if the points-to set of this node +/// intersects with the points-to set of the specified node on any nodes +/// except for the specified node to ignore. +bool Andersens::Node::intersectsIgnoring(Node *N, unsigned Ignoring) const { + // TODO: If we are only going to call this with the same value for Ignoring, + // we should move the special values out of the points-to bitmap. + bool WeHadIt = PointsTo->test(Ignoring); + bool NHadIt = N->PointsTo->test(Ignoring); + bool Result = false; + if (WeHadIt) + PointsTo->reset(Ignoring); + if (NHadIt) + N->PointsTo->reset(Ignoring); + Result = PointsTo->intersects(N->PointsTo); + if (WeHadIt) + PointsTo->set(Ignoring); + if (NHadIt) + N->PointsTo->set(Ignoring); + return Result; +} + +void dumpToDOUT(SparseBitVector<> *bitmap) { +#ifndef NDEBUG + dump(*bitmap, DOUT); +#endif +} + + +/// Clump together address taken variables so that the points-to sets use up +/// less space and can be operated on faster. + +void Andersens::ClumpAddressTaken() { +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa-renumber" + std::vector<unsigned> Translate; + std::vector<Node> NewGraphNodes; + + Translate.resize(GraphNodes.size()); + unsigned NewPos = 0; + + for (unsigned i = 0; i < Constraints.size(); ++i) { + Constraint &C = Constraints[i]; + if (C.Type == Constraint::AddressOf) { + GraphNodes[C.Src].AddressTaken = true; + } + } + for (unsigned i = 0; i < NumberSpecialNodes; ++i) { + unsigned Pos = NewPos++; + Translate[i] = Pos; + NewGraphNodes.push_back(GraphNodes[i]); + DOUT << "Renumbering node " << i << " to node " << Pos << "\n"; + } + + // I believe this ends up being faster than making two vectors and splicing + // them. 
+ for (unsigned i = NumberSpecialNodes; i < GraphNodes.size(); ++i) { + if (GraphNodes[i].AddressTaken) { + unsigned Pos = NewPos++; + Translate[i] = Pos; + NewGraphNodes.push_back(GraphNodes[i]); + DOUT << "Renumbering node " << i << " to node " << Pos << "\n"; + } + } + + for (unsigned i = NumberSpecialNodes; i < GraphNodes.size(); ++i) { + if (!GraphNodes[i].AddressTaken) { + unsigned Pos = NewPos++; + Translate[i] = Pos; + NewGraphNodes.push_back(GraphNodes[i]); + DOUT << "Renumbering node " << i << " to node " << Pos << "\n"; + } + } + + for (DenseMap<Value*, unsigned>::iterator Iter = ValueNodes.begin(); + Iter != ValueNodes.end(); + ++Iter) + Iter->second = Translate[Iter->second]; + + for (DenseMap<Value*, unsigned>::iterator Iter = ObjectNodes.begin(); + Iter != ObjectNodes.end(); + ++Iter) + Iter->second = Translate[Iter->second]; + + for (DenseMap<Function*, unsigned>::iterator Iter = ReturnNodes.begin(); + Iter != ReturnNodes.end(); + ++Iter) + Iter->second = Translate[Iter->second]; + + for (DenseMap<Function*, unsigned>::iterator Iter = VarargNodes.begin(); + Iter != VarargNodes.end(); + ++Iter) + Iter->second = Translate[Iter->second]; + + for (unsigned i = 0; i < Constraints.size(); ++i) { + Constraint &C = Constraints[i]; + C.Src = Translate[C.Src]; + C.Dest = Translate[C.Dest]; + } + + GraphNodes.swap(NewGraphNodes); +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa" +} + +/// The technique used here is described in "Exploiting Pointer and Location +/// Equivalence to Optimize Pointer Analysis. In the 14th International Static +/// Analysis Symposium (SAS), August 2007." It is known as the "HVN" algorithm, +/// and is equivalent to value numbering the collapsed constraint graph without +/// evaluating unions. This is used as a pre-pass to HU in order to resolve +/// first order pointer dereferences and speed up/reduce memory usage of HU. +/// Running both is equivalent to HRU without the iteration +/// HVN in more detail: +/// Imagine the set of constraints was simply straight line code with no loops +/// (we eliminate cycles, so there are no loops), such as: +/// E = &D +/// E = &C +/// E = F +/// F = G +/// G = F +/// Applying value numbering to this code tells us: +/// G == F == E +/// +/// For HVN, this is as far as it goes. We assign new value numbers to every +/// "address node", and every "reference node". +/// To get the optimal result for this, we use a DFS + SCC (since all nodes in a +/// cycle must have the same value number since the = operation is really +/// inclusion, not overwrite), and value number nodes we receive points-to sets +/// before we value our own node. +/// The advantage of HU over HVN is that HU considers the inclusion property, so +/// that if you have +/// E = &D +/// E = &C +/// E = F +/// F = G +/// F = &D +/// G = F +/// HU will determine that G == F == E. HVN will not, because it cannot prove +/// that the points to information ends up being the same because they all +/// receive &D from E anyway. + +void Andersens::HVN() { + DOUT << "Beginning HVN\n"; + // Build a predecessor graph. This is like our constraint graph with the + // edges going in the opposite direction, and there are edges for all the + // constraints, instead of just copy constraints. We also build implicit + // edges for constraints are implied but not explicit. I.E for the constraint + // a = &b, we add implicit edges *a = b. 
This helps us capture more cycles + for (unsigned i = 0, e = Constraints.size(); i != e; ++i) { + Constraint &C = Constraints[i]; + if (C.Type == Constraint::AddressOf) { + GraphNodes[C.Src].AddressTaken = true; + GraphNodes[C.Src].Direct = false; + + // Dest = &src edge + unsigned AdrNode = C.Src + FirstAdrNode; + if (!GraphNodes[C.Dest].PredEdges) + GraphNodes[C.Dest].PredEdges = new SparseBitVector<>; + GraphNodes[C.Dest].PredEdges->set(AdrNode); + + // *Dest = src edge + unsigned RefNode = C.Dest + FirstRefNode; + if (!GraphNodes[RefNode].ImplicitPredEdges) + GraphNodes[RefNode].ImplicitPredEdges = new SparseBitVector<>; + GraphNodes[RefNode].ImplicitPredEdges->set(C.Src); + } else if (C.Type == Constraint::Load) { + if (C.Offset == 0) { + // dest = *src edge + if (!GraphNodes[C.Dest].PredEdges) + GraphNodes[C.Dest].PredEdges = new SparseBitVector<>; + GraphNodes[C.Dest].PredEdges->set(C.Src + FirstRefNode); + } else { + GraphNodes[C.Dest].Direct = false; + } + } else if (C.Type == Constraint::Store) { + if (C.Offset == 0) { + // *dest = src edge + unsigned RefNode = C.Dest + FirstRefNode; + if (!GraphNodes[RefNode].PredEdges) + GraphNodes[RefNode].PredEdges = new SparseBitVector<>; + GraphNodes[RefNode].PredEdges->set(C.Src); + } + } else { + // Dest = Src edge and *Dest = *Src edge + if (!GraphNodes[C.Dest].PredEdges) + GraphNodes[C.Dest].PredEdges = new SparseBitVector<>; + GraphNodes[C.Dest].PredEdges->set(C.Src); + unsigned RefNode = C.Dest + FirstRefNode; + if (!GraphNodes[RefNode].ImplicitPredEdges) + GraphNodes[RefNode].ImplicitPredEdges = new SparseBitVector<>; + GraphNodes[RefNode].ImplicitPredEdges->set(C.Src + FirstRefNode); + } + } + PEClass = 1; + // Do SCC finding first to condense our predecessor graph + DFSNumber = 0; + Node2DFS.insert(Node2DFS.begin(), GraphNodes.size(), 0); + Node2Deleted.insert(Node2Deleted.begin(), GraphNodes.size(), false); + Node2Visited.insert(Node2Visited.begin(), GraphNodes.size(), false); + + for (unsigned i = 0; i < FirstRefNode; ++i) { + unsigned Node = VSSCCRep[i]; + if (!Node2Visited[Node]) + HVNValNum(Node); + } + for (BitVectorMap::iterator Iter = Set2PEClass.begin(); + Iter != Set2PEClass.end(); + ++Iter) + delete Iter->first; + Set2PEClass.clear(); + Node2DFS.clear(); + Node2Deleted.clear(); + Node2Visited.clear(); + DOUT << "Finished HVN\n"; + +} + +/// This is the workhorse of HVN value numbering. We combine SCC finding at the +/// same time because it's easy. 
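// A tiny worked example of the labeling done here (hypothetical constraints):
// given  x = &a,  y = x,  z = x,  the address node ADR(a) is created
// non-direct and so receives a fresh label, say L.  x's only predecessor
// label is L, and y's and z's only predecessor label is x's, so all three end
// up with PointerEquivLabel L: they are pointer equivalent (each can only
// ever point to {a}), and RewriteConstraints/UnitePointerEquivalences can
// merge them later without any points-to sets having been computed.  A node
// whose predecessors contribute no labels at all gets label 0, marking it as
// a provable non-pointer whose constraints can be dropped.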
+void Andersens::HVNValNum(unsigned NodeIndex) { + unsigned MyDFS = DFSNumber++; + Node *N = &GraphNodes[NodeIndex]; + Node2Visited[NodeIndex] = true; + Node2DFS[NodeIndex] = MyDFS; + + // First process all our explicit edges + if (N->PredEdges) + for (SparseBitVector<>::iterator Iter = N->PredEdges->begin(); + Iter != N->PredEdges->end(); + ++Iter) { + unsigned j = VSSCCRep[*Iter]; + if (!Node2Deleted[j]) { + if (!Node2Visited[j]) + HVNValNum(j); + if (Node2DFS[NodeIndex] > Node2DFS[j]) + Node2DFS[NodeIndex] = Node2DFS[j]; + } + } + + // Now process all the implicit edges + if (N->ImplicitPredEdges) + for (SparseBitVector<>::iterator Iter = N->ImplicitPredEdges->begin(); + Iter != N->ImplicitPredEdges->end(); + ++Iter) { + unsigned j = VSSCCRep[*Iter]; + if (!Node2Deleted[j]) { + if (!Node2Visited[j]) + HVNValNum(j); + if (Node2DFS[NodeIndex] > Node2DFS[j]) + Node2DFS[NodeIndex] = Node2DFS[j]; + } + } + + // See if we found any cycles + if (MyDFS == Node2DFS[NodeIndex]) { + while (!SCCStack.empty() && Node2DFS[SCCStack.top()] >= MyDFS) { + unsigned CycleNodeIndex = SCCStack.top(); + Node *CycleNode = &GraphNodes[CycleNodeIndex]; + VSSCCRep[CycleNodeIndex] = NodeIndex; + // Unify the nodes + N->Direct &= CycleNode->Direct; + + if (CycleNode->PredEdges) { + if (!N->PredEdges) + N->PredEdges = new SparseBitVector<>; + *(N->PredEdges) |= CycleNode->PredEdges; + delete CycleNode->PredEdges; + CycleNode->PredEdges = NULL; + } + if (CycleNode->ImplicitPredEdges) { + if (!N->ImplicitPredEdges) + N->ImplicitPredEdges = new SparseBitVector<>; + *(N->ImplicitPredEdges) |= CycleNode->ImplicitPredEdges; + delete CycleNode->ImplicitPredEdges; + CycleNode->ImplicitPredEdges = NULL; + } + + SCCStack.pop(); + } + + Node2Deleted[NodeIndex] = true; + + if (!N->Direct) { + GraphNodes[NodeIndex].PointerEquivLabel = PEClass++; + return; + } + + // Collect labels of successor nodes + bool AllSame = true; + unsigned First = ~0; + SparseBitVector<> *Labels = new SparseBitVector<>; + bool Used = false; + + if (N->PredEdges) + for (SparseBitVector<>::iterator Iter = N->PredEdges->begin(); + Iter != N->PredEdges->end(); + ++Iter) { + unsigned j = VSSCCRep[*Iter]; + unsigned Label = GraphNodes[j].PointerEquivLabel; + // Ignore labels that are equal to us or non-pointers + if (j == NodeIndex || Label == 0) + continue; + if (First == (unsigned)~0) + First = Label; + else if (First != Label) + AllSame = false; + Labels->set(Label); + } + + // We either have a non-pointer, a copy of an existing node, or a new node. + // Assign the appropriate pointer equivalence label. + if (Labels->empty()) { + GraphNodes[NodeIndex].PointerEquivLabel = 0; + } else if (AllSame) { + GraphNodes[NodeIndex].PointerEquivLabel = First; + } else { + GraphNodes[NodeIndex].PointerEquivLabel = Set2PEClass[Labels]; + if (GraphNodes[NodeIndex].PointerEquivLabel == 0) { + unsigned EquivClass = PEClass++; + Set2PEClass[Labels] = EquivClass; + GraphNodes[NodeIndex].PointerEquivLabel = EquivClass; + Used = true; + } + } + if (!Used) + delete Labels; + } else { + SCCStack.push(NodeIndex); + } +} + +/// The technique used here is described in "Exploiting Pointer and Location +/// Equivalence to Optimize Pointer Analysis. In the 14th International Static +/// Analysis Symposium (SAS), August 2007." It is known as the "HU" algorithm, +/// and is equivalent to value numbering the collapsed constraint graph +/// including evaluating unions. +void Andersens::HU() { + DOUT << "Beginning HU\n"; + // Build a predecessor graph. 
This is like our constraint graph with the + // edges going in the opposite direction, and there are edges for all the + // constraints, instead of just copy constraints. We also build implicit + // edges for constraints are implied but not explicit. I.E for the constraint + // a = &b, we add implicit edges *a = b. This helps us capture more cycles + for (unsigned i = 0, e = Constraints.size(); i != e; ++i) { + Constraint &C = Constraints[i]; + if (C.Type == Constraint::AddressOf) { + GraphNodes[C.Src].AddressTaken = true; + GraphNodes[C.Src].Direct = false; + + GraphNodes[C.Dest].PointsTo->set(C.Src); + // *Dest = src edge + unsigned RefNode = C.Dest + FirstRefNode; + if (!GraphNodes[RefNode].ImplicitPredEdges) + GraphNodes[RefNode].ImplicitPredEdges = new SparseBitVector<>; + GraphNodes[RefNode].ImplicitPredEdges->set(C.Src); + GraphNodes[C.Src].PointedToBy->set(C.Dest); + } else if (C.Type == Constraint::Load) { + if (C.Offset == 0) { + // dest = *src edge + if (!GraphNodes[C.Dest].PredEdges) + GraphNodes[C.Dest].PredEdges = new SparseBitVector<>; + GraphNodes[C.Dest].PredEdges->set(C.Src + FirstRefNode); + } else { + GraphNodes[C.Dest].Direct = false; + } + } else if (C.Type == Constraint::Store) { + if (C.Offset == 0) { + // *dest = src edge + unsigned RefNode = C.Dest + FirstRefNode; + if (!GraphNodes[RefNode].PredEdges) + GraphNodes[RefNode].PredEdges = new SparseBitVector<>; + GraphNodes[RefNode].PredEdges->set(C.Src); + } + } else { + // Dest = Src edge and *Dest = *Src edg + if (!GraphNodes[C.Dest].PredEdges) + GraphNodes[C.Dest].PredEdges = new SparseBitVector<>; + GraphNodes[C.Dest].PredEdges->set(C.Src); + unsigned RefNode = C.Dest + FirstRefNode; + if (!GraphNodes[RefNode].ImplicitPredEdges) + GraphNodes[RefNode].ImplicitPredEdges = new SparseBitVector<>; + GraphNodes[RefNode].ImplicitPredEdges->set(C.Src + FirstRefNode); + } + } + PEClass = 1; + // Do SCC finding first to condense our predecessor graph + DFSNumber = 0; + Node2DFS.insert(Node2DFS.begin(), GraphNodes.size(), 0); + Node2Deleted.insert(Node2Deleted.begin(), GraphNodes.size(), false); + Node2Visited.insert(Node2Visited.begin(), GraphNodes.size(), false); + + for (unsigned i = 0; i < FirstRefNode; ++i) { + if (FindNode(i) == i) { + unsigned Node = VSSCCRep[i]; + if (!Node2Visited[Node]) + Condense(Node); + } + } + + // Reset tables for actual labeling + Node2DFS.clear(); + Node2Visited.clear(); + Node2Deleted.clear(); + // Pre-grow our densemap so that we don't get really bad behavior + Set2PEClass.resize(GraphNodes.size()); + + // Visit the condensed graph and generate pointer equivalence labels. + Node2Visited.insert(Node2Visited.begin(), GraphNodes.size(), false); + for (unsigned i = 0; i < FirstRefNode; ++i) { + if (FindNode(i) == i) { + unsigned Node = VSSCCRep[i]; + if (!Node2Visited[Node]) + HUValNum(Node); + } + } + // PEClass nodes will be deleted by the deleting of N->PointsTo in our caller. + Set2PEClass.clear(); + DOUT << "Finished HU\n"; +} + + +/// Implementation of standard Tarjan SCC algorithm as modified by Nuutilla. 
+void Andersens::Condense(unsigned NodeIndex) { + unsigned MyDFS = DFSNumber++; + Node *N = &GraphNodes[NodeIndex]; + Node2Visited[NodeIndex] = true; + Node2DFS[NodeIndex] = MyDFS; + + // First process all our explicit edges + if (N->PredEdges) + for (SparseBitVector<>::iterator Iter = N->PredEdges->begin(); + Iter != N->PredEdges->end(); + ++Iter) { + unsigned j = VSSCCRep[*Iter]; + if (!Node2Deleted[j]) { + if (!Node2Visited[j]) + Condense(j); + if (Node2DFS[NodeIndex] > Node2DFS[j]) + Node2DFS[NodeIndex] = Node2DFS[j]; + } + } + + // Now process all the implicit edges + if (N->ImplicitPredEdges) + for (SparseBitVector<>::iterator Iter = N->ImplicitPredEdges->begin(); + Iter != N->ImplicitPredEdges->end(); + ++Iter) { + unsigned j = VSSCCRep[*Iter]; + if (!Node2Deleted[j]) { + if (!Node2Visited[j]) + Condense(j); + if (Node2DFS[NodeIndex] > Node2DFS[j]) + Node2DFS[NodeIndex] = Node2DFS[j]; + } + } + + // See if we found any cycles + if (MyDFS == Node2DFS[NodeIndex]) { + while (!SCCStack.empty() && Node2DFS[SCCStack.top()] >= MyDFS) { + unsigned CycleNodeIndex = SCCStack.top(); + Node *CycleNode = &GraphNodes[CycleNodeIndex]; + VSSCCRep[CycleNodeIndex] = NodeIndex; + // Unify the nodes + N->Direct &= CycleNode->Direct; + + *(N->PointsTo) |= CycleNode->PointsTo; + delete CycleNode->PointsTo; + CycleNode->PointsTo = NULL; + if (CycleNode->PredEdges) { + if (!N->PredEdges) + N->PredEdges = new SparseBitVector<>; + *(N->PredEdges) |= CycleNode->PredEdges; + delete CycleNode->PredEdges; + CycleNode->PredEdges = NULL; + } + if (CycleNode->ImplicitPredEdges) { + if (!N->ImplicitPredEdges) + N->ImplicitPredEdges = new SparseBitVector<>; + *(N->ImplicitPredEdges) |= CycleNode->ImplicitPredEdges; + delete CycleNode->ImplicitPredEdges; + CycleNode->ImplicitPredEdges = NULL; + } + SCCStack.pop(); + } + + Node2Deleted[NodeIndex] = true; + + // Set up number of incoming edges for other nodes + if (N->PredEdges) + for (SparseBitVector<>::iterator Iter = N->PredEdges->begin(); + Iter != N->PredEdges->end(); + ++Iter) + ++GraphNodes[VSSCCRep[*Iter]].NumInEdges; + } else { + SCCStack.push(NodeIndex); + } +} + +void Andersens::HUValNum(unsigned NodeIndex) { + Node *N = &GraphNodes[NodeIndex]; + Node2Visited[NodeIndex] = true; + + // Eliminate dereferences of non-pointers for those non-pointers we have + // already identified. These are ref nodes whose non-ref node: + // 1. Has already been visited determined to point to nothing (and thus, a + // dereference of it must point to nothing) + // 2. Any direct node with no predecessor edges in our graph and with no + // points-to set (since it can't point to anything either, being that it + // receives no points-to sets and has none). + if (NodeIndex >= FirstRefNode) { + unsigned j = VSSCCRep[FindNode(NodeIndex - FirstRefNode)]; + if ((Node2Visited[j] && !GraphNodes[j].PointerEquivLabel) + || (GraphNodes[j].Direct && !GraphNodes[j].PredEdges + && GraphNodes[j].PointsTo->empty())){ + return; + } + } + // Process all our explicit edges + if (N->PredEdges) + for (SparseBitVector<>::iterator Iter = N->PredEdges->begin(); + Iter != N->PredEdges->end(); + ++Iter) { + unsigned j = VSSCCRep[*Iter]; + if (!Node2Visited[j]) + HUValNum(j); + + // If this edge turned out to be the same as us, or got no pointer + // equivalence label (and thus points to nothing) , just decrement our + // incoming edges and continue. 
+ if (j == NodeIndex || GraphNodes[j].PointerEquivLabel == 0) { + --GraphNodes[j].NumInEdges; + continue; + } + + *(N->PointsTo) |= GraphNodes[j].PointsTo; + + // If we didn't end up storing this in the hash, and we're done with all + // the edges, we don't need the points-to set anymore. + --GraphNodes[j].NumInEdges; + if (!GraphNodes[j].NumInEdges && !GraphNodes[j].StoredInHash) { + delete GraphNodes[j].PointsTo; + GraphNodes[j].PointsTo = NULL; + } + } + // If this isn't a direct node, generate a fresh variable. + if (!N->Direct) { + N->PointsTo->set(FirstRefNode + NodeIndex); + } + + // See If we have something equivalent to us, if not, generate a new + // equivalence class. + if (N->PointsTo->empty()) { + delete N->PointsTo; + N->PointsTo = NULL; + } else { + if (N->Direct) { + N->PointerEquivLabel = Set2PEClass[N->PointsTo]; + if (N->PointerEquivLabel == 0) { + unsigned EquivClass = PEClass++; + N->StoredInHash = true; + Set2PEClass[N->PointsTo] = EquivClass; + N->PointerEquivLabel = EquivClass; + } + } else { + N->PointerEquivLabel = PEClass++; + } + } +} + +/// Rewrite our list of constraints so that pointer equivalent nodes are +/// replaced by their the pointer equivalence class representative. +void Andersens::RewriteConstraints() { + std::vector<Constraint> NewConstraints; + DenseSet<Constraint, ConstraintKeyInfo> Seen; + + PEClass2Node.clear(); + PENLEClass2Node.clear(); + + // We may have from 1 to Graphnodes + 1 equivalence classes. + PEClass2Node.insert(PEClass2Node.begin(), GraphNodes.size() + 1, -1); + PENLEClass2Node.insert(PENLEClass2Node.begin(), GraphNodes.size() + 1, -1); + + // Rewrite constraints, ignoring non-pointer constraints, uniting equivalent + // nodes, and rewriting constraints to use the representative nodes. + for (unsigned i = 0, e = Constraints.size(); i != e; ++i) { + Constraint &C = Constraints[i]; + unsigned RHSNode = FindNode(C.Src); + unsigned LHSNode = FindNode(C.Dest); + unsigned RHSLabel = GraphNodes[VSSCCRep[RHSNode]].PointerEquivLabel; + unsigned LHSLabel = GraphNodes[VSSCCRep[LHSNode]].PointerEquivLabel; + + // First we try to eliminate constraints for things we can prove don't point + // to anything. + if (LHSLabel == 0) { + DEBUG(PrintNode(&GraphNodes[LHSNode])); + DOUT << " is a non-pointer, ignoring constraint.\n"; + continue; + } + if (RHSLabel == 0) { + DEBUG(PrintNode(&GraphNodes[RHSNode])); + DOUT << " is a non-pointer, ignoring constraint.\n"; + continue; + } + // This constraint may be useless, and it may become useless as we translate + // it. + if (C.Src == C.Dest && C.Type == Constraint::Copy) + continue; + + C.Src = FindEquivalentNode(RHSNode, RHSLabel); + C.Dest = FindEquivalentNode(FindNode(LHSNode), LHSLabel); + if ((C.Src == C.Dest && C.Type == Constraint::Copy) + || Seen.count(C)) + continue; + + Seen.insert(C); + NewConstraints.push_back(C); + } + Constraints.swap(NewConstraints); + PEClass2Node.clear(); +} + +/// See if we have a node that is pointer equivalent to the one being asked +/// about, and if so, unite them and return the equivalent node. Otherwise, +/// return the original node. +unsigned Andersens::FindEquivalentNode(unsigned NodeIndex, + unsigned NodeLabel) { + if (!GraphNodes[NodeIndex].AddressTaken) { + if (PEClass2Node[NodeLabel] != -1) { + // We found an existing node with the same pointer label, so unify them. + // We specifically request that Union-By-Rank not be used so that + // PEClass2Node[NodeLabel] U= NodeIndex and not the other way around. 
+ return UniteNodes(PEClass2Node[NodeLabel], NodeIndex, false); + } else { + PEClass2Node[NodeLabel] = NodeIndex; + PENLEClass2Node[NodeLabel] = NodeIndex; + } + } else if (PENLEClass2Node[NodeLabel] == -1) { + PENLEClass2Node[NodeLabel] = NodeIndex; + } + + return NodeIndex; +} + +void Andersens::PrintLabels() const { + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + if (i < FirstRefNode) { + PrintNode(&GraphNodes[i]); + } else if (i < FirstAdrNode) { + DOUT << "REF("; + PrintNode(&GraphNodes[i-FirstRefNode]); + DOUT <<")"; + } else { + DOUT << "ADR("; + PrintNode(&GraphNodes[i-FirstAdrNode]); + DOUT <<")"; + } + + DOUT << " has pointer label " << GraphNodes[i].PointerEquivLabel + << " and SCC rep " << VSSCCRep[i] + << " and is " << (GraphNodes[i].Direct ? "Direct" : "Not direct") + << "\n"; + } +} + +/// The technique used here is described in "The Ant and the +/// Grasshopper: Fast and Accurate Pointer Analysis for Millions of +/// Lines of Code. In Programming Language Design and Implementation +/// (PLDI), June 2007." It is known as the "HCD" (Hybrid Cycle +/// Detection) algorithm. It is called a hybrid because it performs an +/// offline analysis and uses its results during the solving (online) +/// phase. This is just the offline portion; the results of this +/// operation are stored in SDT and are later used in SolveContraints() +/// and UniteNodes(). +void Andersens::HCD() { + DOUT << "Starting HCD.\n"; + HCDSCCRep.resize(GraphNodes.size()); + + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + GraphNodes[i].Edges = new SparseBitVector<>; + HCDSCCRep[i] = i; + } + + for (unsigned i = 0, e = Constraints.size(); i != e; ++i) { + Constraint &C = Constraints[i]; + assert (C.Src < GraphNodes.size() && C.Dest < GraphNodes.size()); + if (C.Type == Constraint::AddressOf) { + continue; + } else if (C.Type == Constraint::Load) { + if( C.Offset == 0 ) + GraphNodes[C.Dest].Edges->set(C.Src + FirstRefNode); + } else if (C.Type == Constraint::Store) { + if( C.Offset == 0 ) + GraphNodes[C.Dest + FirstRefNode].Edges->set(C.Src); + } else { + GraphNodes[C.Dest].Edges->set(C.Src); + } + } + + Node2DFS.insert(Node2DFS.begin(), GraphNodes.size(), 0); + Node2Deleted.insert(Node2Deleted.begin(), GraphNodes.size(), false); + Node2Visited.insert(Node2Visited.begin(), GraphNodes.size(), false); + SDT.insert(SDT.begin(), GraphNodes.size() / 2, -1); + + DFSNumber = 0; + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + unsigned Node = HCDSCCRep[i]; + if (!Node2Deleted[Node]) + Search(Node); + } + + for (unsigned i = 0; i < GraphNodes.size(); ++i) + if (GraphNodes[i].Edges != NULL) { + delete GraphNodes[i].Edges; + GraphNodes[i].Edges = NULL; + } + + while( !SCCStack.empty() ) + SCCStack.pop(); + + Node2DFS.clear(); + Node2Visited.clear(); + Node2Deleted.clear(); + HCDSCCRep.clear(); + DOUT << "HCD complete.\n"; +} + +// Component of HCD: +// Use Nuutila's variant of Tarjan's algorithm to detect +// Strongly-Connected Components (SCCs). For non-trivial SCCs +// containing ref nodes, insert the appropriate information in SDT. 
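// HCD's offline graph covers both the variables and their ref nodes (*x,
// encoded here as x + N, mirroring FirstRefNode above): a zero-offset load
// d = *s connects d and *s, a zero-offset store *d = s connects *d and s, and
// a copy connects d and s.  The edges below run from destination back to
// source, matching the orientation used in the code above; the direction is
// irrelevant for SCC detection.  Any SCC containing a ref node *x is what SDT
// later records.  A sketch under those assumptions; CKind and
// SimpleConstraint are invented for illustration:

#include <vector>

enum class CKind { AddressOf, Copy, Load, Store };
struct SimpleConstraint { CKind Type; unsigned Dest, Src; };

// Returns an adjacency list over 2*N nodes: node i is variable i,
// node i + N is its ref node *i.
std::vector<std::vector<unsigned>>
buildHCDGraph(const std::vector<SimpleConstraint> &Cs, unsigned N) {
  std::vector<std::vector<unsigned>> Succ(2 * N);
  for (const SimpleConstraint &C : Cs) {
    switch (C.Type) {
    case CKind::AddressOf: break;                                // no offline edge
    case CKind::Load:  Succ[C.Dest].push_back(C.Src + N); break; // d -> *s
    case CKind::Store: Succ[C.Dest + N].push_back(C.Src); break; // *d -> s
    case CKind::Copy:  Succ[C.Dest].push_back(C.Src); break;     // d -> s
    }
  }
  return Succ;
}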
+void Andersens::Search(unsigned Node) { + unsigned MyDFS = DFSNumber++; + + Node2Visited[Node] = true; + Node2DFS[Node] = MyDFS; + + for (SparseBitVector<>::iterator Iter = GraphNodes[Node].Edges->begin(), + End = GraphNodes[Node].Edges->end(); + Iter != End; + ++Iter) { + unsigned J = HCDSCCRep[*Iter]; + assert(GraphNodes[J].isRep() && "Debug check; must be representative"); + if (!Node2Deleted[J]) { + if (!Node2Visited[J]) + Search(J); + if (Node2DFS[Node] > Node2DFS[J]) + Node2DFS[Node] = Node2DFS[J]; + } + } + + if( MyDFS != Node2DFS[Node] ) { + SCCStack.push(Node); + return; + } + + // This node is the root of a SCC, so process it. + // + // If the SCC is "non-trivial" (not a singleton) and contains a reference + // node, we place this SCC into SDT. We unite the nodes in any case. + if (!SCCStack.empty() && Node2DFS[SCCStack.top()] >= MyDFS) { + SparseBitVector<> SCC; + + SCC.set(Node); + + bool Ref = (Node >= FirstRefNode); + + Node2Deleted[Node] = true; + + do { + unsigned P = SCCStack.top(); SCCStack.pop(); + Ref |= (P >= FirstRefNode); + SCC.set(P); + HCDSCCRep[P] = Node; + } while (!SCCStack.empty() && Node2DFS[SCCStack.top()] >= MyDFS); + + if (Ref) { + unsigned Rep = SCC.find_first(); + assert(Rep < FirstRefNode && "The SCC didn't have a non-Ref node!"); + + SparseBitVector<>::iterator i = SCC.begin(); + + // Skip over the non-ref nodes + while( *i < FirstRefNode ) + ++i; + + while( i != SCC.end() ) + SDT[ (*i++) - FirstRefNode ] = Rep; + } + } +} + + +/// Optimize the constraints by performing offline variable substitution and +/// other optimizations. +void Andersens::OptimizeConstraints() { + DOUT << "Beginning constraint optimization\n"; + + SDTActive = false; + + // Function related nodes need to stay in the same relative position and can't + // be location equivalent. + for (std::map<unsigned, unsigned>::iterator Iter = MaxK.begin(); + Iter != MaxK.end(); + ++Iter) { + for (unsigned i = Iter->first; + i != Iter->first + Iter->second; + ++i) { + GraphNodes[i].AddressTaken = true; + GraphNodes[i].Direct = false; + } + } + + ClumpAddressTaken(); + FirstRefNode = GraphNodes.size(); + FirstAdrNode = FirstRefNode + GraphNodes.size(); + GraphNodes.insert(GraphNodes.end(), 2 * GraphNodes.size(), + Node(false)); + VSSCCRep.resize(GraphNodes.size()); + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + VSSCCRep[i] = i; + } + HVN(); + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + Node *N = &GraphNodes[i]; + delete N->PredEdges; + N->PredEdges = NULL; + delete N->ImplicitPredEdges; + N->ImplicitPredEdges = NULL; + } +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa-labels" + DEBUG(PrintLabels()); +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa" + RewriteConstraints(); + // Delete the adr nodes. 
+ GraphNodes.resize(FirstRefNode * 2); + + // Now perform HU + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + Node *N = &GraphNodes[i]; + if (FindNode(i) == i) { + N->PointsTo = new SparseBitVector<>; + N->PointedToBy = new SparseBitVector<>; + // Reset our labels + } + VSSCCRep[i] = i; + N->PointerEquivLabel = 0; + } + HU(); +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa-labels" + DEBUG(PrintLabels()); +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa" + RewriteConstraints(); + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + if (FindNode(i) == i) { + Node *N = &GraphNodes[i]; + delete N->PointsTo; + N->PointsTo = NULL; + delete N->PredEdges; + N->PredEdges = NULL; + delete N->ImplicitPredEdges; + N->ImplicitPredEdges = NULL; + delete N->PointedToBy; + N->PointedToBy = NULL; + } + } + + // perform Hybrid Cycle Detection (HCD) + HCD(); + SDTActive = true; + + // No longer any need for the upper half of GraphNodes (for ref nodes). + GraphNodes.erase(GraphNodes.begin() + FirstRefNode, GraphNodes.end()); + + // HCD complete. + + DOUT << "Finished constraint optimization\n"; + FirstRefNode = 0; + FirstAdrNode = 0; +} + +/// Unite pointer but not location equivalent variables, now that the constraint +/// graph is built. +void Andersens::UnitePointerEquivalences() { + DOUT << "Uniting remaining pointer equivalences\n"; + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + if (GraphNodes[i].AddressTaken && GraphNodes[i].isRep()) { + unsigned Label = GraphNodes[i].PointerEquivLabel; + + if (Label && PENLEClass2Node[Label] != -1) + UniteNodes(i, PENLEClass2Node[Label]); + } + } + DOUT << "Finished remaining pointer equivalences\n"; + PENLEClass2Node.clear(); +} + +/// Create the constraint graph used for solving points-to analysis. +/// +void Andersens::CreateConstraintGraph() { + for (unsigned i = 0, e = Constraints.size(); i != e; ++i) { + Constraint &C = Constraints[i]; + assert (C.Src < GraphNodes.size() && C.Dest < GraphNodes.size()); + if (C.Type == Constraint::AddressOf) + GraphNodes[C.Dest].PointsTo->set(C.Src); + else if (C.Type == Constraint::Load) + GraphNodes[C.Src].Constraints.push_back(C); + else if (C.Type == Constraint::Store) + GraphNodes[C.Dest].Constraints.push_back(C); + else if (C.Offset != 0) + GraphNodes[C.Src].Constraints.push_back(C); + else + GraphNodes[C.Src].Edges->set(C.Dest); + } +} + +// Perform DFS and cycle detection. +bool Andersens::QueryNode(unsigned Node) { + assert(GraphNodes[Node].isRep() && "Querying a non-rep node"); + unsigned OurDFS = ++DFSNumber; + SparseBitVector<> ToErase; + SparseBitVector<> NewEdges; + Tarjan2DFS[Node] = OurDFS; + + // Changed denotes a change from a recursive call that we will bubble up. + // Merged is set if we actually merge a node ourselves. + bool Changed = false, Merged = false; + + for (SparseBitVector<>::iterator bi = GraphNodes[Node].Edges->begin(); + bi != GraphNodes[Node].Edges->end(); + ++bi) { + unsigned RepNode = FindNode(*bi); + // If this edge points to a non-representative node but we are + // already planning to add an edge to its representative, we have no + // need for this edge anymore. + if (RepNode != *bi && NewEdges.test(RepNode)){ + ToErase.set(*bi); + continue; + } + + // Continue about our DFS. 
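// CreateConstraintGraph splits the constraint list three ways: address-of
// constraints seed the initial points-to sets, zero-offset copies become
// plain graph edges that the solver handles with bitmap unions, and loads,
// stores, and offsetted copies stay behind as "complex" constraints attached
// to the node whose points-to set must be walked to resolve them.  A
// stripped-down version of that dispatch; GraphConstraint and SolverNode are
// invented stand-ins for the real node and constraint types:

#include <list>
#include <set>
#include <vector>

struct GraphConstraint {
  enum Kind { AddressOf, Copy, Load, Store } Type;
  unsigned Dest, Src, Offset;
};

struct SolverNode {
  std::set<unsigned> PointsTo;           // current pts(n)
  std::set<unsigned> Edges;              // zero-offset copy edges out of n
  std::list<GraphConstraint> Complex;    // re-evaluated as pts(n) grows
};

void seedConstraintGraph(const std::vector<GraphConstraint> &Cs,
                         std::vector<SolverNode> &Nodes) {
  for (const GraphConstraint &C : Cs) {
    if (C.Type == GraphConstraint::AddressOf)
      Nodes[C.Dest].PointsTo.insert(C.Src);       // d = &s:  s is in pts(d)
    else if (C.Type == GraphConstraint::Load)
      Nodes[C.Src].Complex.push_back(C);          // d = *s:  watch pts(s)
    else if (C.Type == GraphConstraint::Store)
      Nodes[C.Dest].Complex.push_back(C);         // *d = s:  watch pts(d)
    else if (C.Offset != 0)
      Nodes[C.Src].Complex.push_back(C);          // d = s + k: still complex
    else
      Nodes[C.Src].Edges.insert(C.Dest);          // plain copy edge s -> d
  }
}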
+ if (!Tarjan2Deleted[RepNode]){ + if (Tarjan2DFS[RepNode] == 0) { + Changed |= QueryNode(RepNode); + // May have been changed by QueryNode + RepNode = FindNode(RepNode); + } + if (Tarjan2DFS[RepNode] < Tarjan2DFS[Node]) + Tarjan2DFS[Node] = Tarjan2DFS[RepNode]; + } + + // We may have just discovered that this node is part of a cycle, in + // which case we can also erase it. + if (RepNode != *bi) { + ToErase.set(*bi); + NewEdges.set(RepNode); + } + } + + GraphNodes[Node].Edges->intersectWithComplement(ToErase); + GraphNodes[Node].Edges |= NewEdges; + + // If this node is a root of a non-trivial SCC, place it on our + // worklist to be processed. + if (OurDFS == Tarjan2DFS[Node]) { + while (!SCCStack.empty() && Tarjan2DFS[SCCStack.top()] >= OurDFS) { + Node = UniteNodes(Node, SCCStack.top()); + + SCCStack.pop(); + Merged = true; + } + Tarjan2Deleted[Node] = true; + + if (Merged) + NextWL->insert(&GraphNodes[Node]); + } else { + SCCStack.push(Node); + } + + return(Changed | Merged); +} + +/// SolveConstraints - This stage iteratively processes the constraints list +/// propagating constraints (adding edges to the Nodes in the points-to graph) +/// until a fixed point is reached. +/// +/// We use a variant of the technique called "Lazy Cycle Detection", which is +/// described in "The Ant and the Grasshopper: Fast and Accurate Pointer +/// Analysis for Millions of Lines of Code. In Programming Language Design and +/// Implementation (PLDI), June 2007." +/// The paper describes performing cycle detection one node at a time, which can +/// be expensive if there are no cycles, but there are long chains of nodes that +/// it heuristically believes are cycles (because it will DFS from each node +/// without state from previous nodes). +/// Instead, we use the heuristic to build a worklist of nodes to check, then +/// cycle detect them all at the same time to do this more cheaply. This +/// catches cycles slightly later than the original technique did, but does it +/// make significantly cheaper. + +void Andersens::SolveConstraints() { + CurrWL = &w1; + NextWL = &w2; + + OptimizeConstraints(); +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa-constraints" + DEBUG(PrintConstraints()); +#undef DEBUG_TYPE +#define DEBUG_TYPE "anders-aa" + + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + Node *N = &GraphNodes[i]; + N->PointsTo = new SparseBitVector<>; + N->OldPointsTo = new SparseBitVector<>; + N->Edges = new SparseBitVector<>; + } + CreateConstraintGraph(); + UnitePointerEquivalences(); + assert(SCCStack.empty() && "SCC Stack should be empty by now!"); + Node2DFS.clear(); + Node2Deleted.clear(); + Node2DFS.insert(Node2DFS.begin(), GraphNodes.size(), 0); + Node2Deleted.insert(Node2Deleted.begin(), GraphNodes.size(), false); + DFSNumber = 0; + DenseSet<Constraint, ConstraintKeyInfo> Seen; + DenseSet<std::pair<unsigned,unsigned>, PairKeyInfo> EdgesChecked; + + // Order graph and add initial nodes to work list. + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + Node *INode = &GraphNodes[i]; + + // Add to work list if it's a representative and can contribute to the + // calculation right now. 
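// At its core the solver is difference propagation on a worklist: pop a node,
// take only the part of its points-to set that is new since the last visit,
// push that delta along every copy edge, and requeue any successor whose set
// actually grew.  A compact sketch of just that loop, with plain sets for the
// points-to and edge information and none of the cycle detection, offsets, or
// HCD handling of the real solver:

#include <deque>
#include <set>
#include <vector>

// PointsTo[n] is pts(n); Edges[n] holds the copy edges n -> m.
void solveCopyEdges(std::vector<std::set<unsigned>> &PointsTo,
                    const std::vector<std::set<unsigned>> &Edges) {
  std::vector<std::set<unsigned>> Old(PointsTo.size());
  std::deque<unsigned> Work;
  for (unsigned i = 0; i < PointsTo.size(); ++i)
    if (!PointsTo[i].empty())
      Work.push_back(i);

  while (!Work.empty()) {
    unsigned N = Work.front();
    Work.pop_front();

    std::set<unsigned> Delta;              // pts(N) minus what we already pushed
    for (unsigned P : PointsTo[N])
      if (!Old[N].count(P))
        Delta.insert(P);
    if (Delta.empty())
      continue;
    Old[N].insert(Delta.begin(), Delta.end());

    for (unsigned Succ : Edges[N]) {
      bool Grew = false;
      for (unsigned P : Delta)
        Grew |= PointsTo[Succ].insert(P).second;
      if (Grew)
        Work.push_back(Succ);              // its own delta must be propagated too
    }
  }
}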
+ if (INode->isRep() && !INode->PointsTo->empty() + && (!INode->Edges->empty() || !INode->Constraints.empty())) { + INode->Stamp(); + CurrWL->insert(INode); + } + } + std::queue<unsigned int> TarjanWL; +#if !FULL_UNIVERSAL + // "Rep and special variables" - in order for HCD to maintain conservative + // results when !FULL_UNIVERSAL, we need to treat the special variables in + // the same way that the !FULL_UNIVERSAL tweak does throughout the rest of + // the analysis - it's ok to add edges from the special nodes, but never + // *to* the special nodes. + std::vector<unsigned int> RSV; +#endif + while( !CurrWL->empty() ) { + DOUT << "Starting iteration #" << ++NumIters << "\n"; + + Node* CurrNode; + unsigned CurrNodeIndex; + + // Actual cycle checking code. We cycle check all of the lazy cycle + // candidates from the last iteration in one go. + if (!TarjanWL.empty()) { + DFSNumber = 0; + + Tarjan2DFS.clear(); + Tarjan2Deleted.clear(); + while (!TarjanWL.empty()) { + unsigned int ToTarjan = TarjanWL.front(); + TarjanWL.pop(); + if (!Tarjan2Deleted[ToTarjan] + && GraphNodes[ToTarjan].isRep() + && Tarjan2DFS[ToTarjan] == 0) + QueryNode(ToTarjan); + } + } + + // Add to work list if it's a representative and can contribute to the + // calculation right now. + while( (CurrNode = CurrWL->pop()) != NULL ) { + CurrNodeIndex = CurrNode - &GraphNodes[0]; + CurrNode->Stamp(); + + + // Figure out the changed points to bits + SparseBitVector<> CurrPointsTo; + CurrPointsTo.intersectWithComplement(CurrNode->PointsTo, + CurrNode->OldPointsTo); + if (CurrPointsTo.empty()) + continue; + + *(CurrNode->OldPointsTo) |= CurrPointsTo; + + // Check the offline-computed equivalencies from HCD. + bool SCC = false; + unsigned Rep; + + if (SDT[CurrNodeIndex] >= 0) { + SCC = true; + Rep = FindNode(SDT[CurrNodeIndex]); + +#if !FULL_UNIVERSAL + RSV.clear(); +#endif + for (SparseBitVector<>::iterator bi = CurrPointsTo.begin(); + bi != CurrPointsTo.end(); ++bi) { + unsigned Node = FindNode(*bi); +#if !FULL_UNIVERSAL + if (Node < NumberSpecialNodes) { + RSV.push_back(Node); + continue; + } +#endif + Rep = UniteNodes(Rep,Node); + } +#if !FULL_UNIVERSAL + RSV.push_back(Rep); +#endif + + NextWL->insert(&GraphNodes[Rep]); + + if ( ! CurrNode->isRep() ) + continue; + } + + Seen.clear(); + + /* Now process the constraints for this node. */ + for (std::list<Constraint>::iterator li = CurrNode->Constraints.begin(); + li != CurrNode->Constraints.end(); ) { + li->Src = FindNode(li->Src); + li->Dest = FindNode(li->Dest); + + // Delete redundant constraints + if( Seen.count(*li) ) { + std::list<Constraint>::iterator lk = li; li++; + + CurrNode->Constraints.erase(lk); + ++NumErased; + continue; + } + Seen.insert(*li); + + // Src and Dest will be the vars we are going to process. + // This may look a bit ugly, but what it does is allow us to process + // both store and load constraints with the same code. + // Load constraints say that every member of our RHS solution has K + // added to it, and that variable gets an edge to LHS. We also union + // RHS+K's solution into the LHS solution. + // Store constraints say that every member of our LHS solution has K + // added to it, and that variable gets an edge from RHS. We also union + // RHS's solution into the LHS+K solution. 
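// CurrPointsTo is the delta: the bits present in PointsTo but not yet in
// OldPointsTo, and only that delta is pushed through the constraints and
// edges below, which is what keeps each visit cheap.  The same operation with
// std::bitset standing in for SparseBitVector (the bit patterns are arbitrary
// examples):

#include <bitset>
#include <cassert>

int main() {
  std::bitset<8> PointsTo("01101101"), OldPointsTo("01001001");
  std::bitset<8> Curr = PointsTo & ~OldPointsTo;    // intersectWithComplement
  assert(Curr == std::bitset<8>("00100100"));       // only the new pointees
  OldPointsTo |= Curr;                              // remember we have seen them
  assert(OldPointsTo == PointsTo);
  return 0;
}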
+ unsigned *Src; + unsigned *Dest; + unsigned K = li->Offset; + unsigned CurrMember; + if (li->Type == Constraint::Load) { + Src = &CurrMember; + Dest = &li->Dest; + } else if (li->Type == Constraint::Store) { + Src = &li->Src; + Dest = &CurrMember; + } else { + // TODO Handle offseted copy constraint + li++; + continue; + } + + // See if we can use Hybrid Cycle Detection (that is, check + // if it was a statically detected offline equivalence that + // involves pointers; if so, remove the redundant constraints). + if( SCC && K == 0 ) { +#if FULL_UNIVERSAL + CurrMember = Rep; + + if (GraphNodes[*Src].Edges->test_and_set(*Dest)) + if (GraphNodes[*Dest].PointsTo |= *(GraphNodes[*Src].PointsTo)) + NextWL->insert(&GraphNodes[*Dest]); +#else + for (unsigned i=0; i < RSV.size(); ++i) { + CurrMember = RSV[i]; + + if (*Dest < NumberSpecialNodes) + continue; + if (GraphNodes[*Src].Edges->test_and_set(*Dest)) + if (GraphNodes[*Dest].PointsTo |= *(GraphNodes[*Src].PointsTo)) + NextWL->insert(&GraphNodes[*Dest]); + } +#endif + // since all future elements of the points-to set will be + // equivalent to the current ones, the complex constraints + // become redundant. + // + std::list<Constraint>::iterator lk = li; li++; +#if !FULL_UNIVERSAL + // In this case, we can still erase the constraints when the + // elements of the points-to sets are referenced by *Dest, + // but not when they are referenced by *Src (i.e. for a Load + // constraint). This is because if another special variable is + // put into the points-to set later, we still need to add the + // new edge from that special variable. + if( lk->Type != Constraint::Load) +#endif + GraphNodes[CurrNodeIndex].Constraints.erase(lk); + } else { + const SparseBitVector<> &Solution = CurrPointsTo; + + for (SparseBitVector<>::iterator bi = Solution.begin(); + bi != Solution.end(); + ++bi) { + CurrMember = *bi; + + // Need to increment the member by K since that is where we are + // supposed to copy to/from. Note that in positive weight cycles, + // which occur in address taking of fields, K can go past + // MaxK[CurrMember] elements, even though that is all it could point + // to. + if (K > 0 && K > MaxK[CurrMember]) + continue; + else + CurrMember = FindNode(CurrMember + K); + + // Add an edge to the graph, so we can just do regular + // bitmap ior next time. It may also let us notice a cycle. +#if !FULL_UNIVERSAL + if (*Dest < NumberSpecialNodes) + continue; +#endif + if (GraphNodes[*Src].Edges->test_and_set(*Dest)) + if (GraphNodes[*Dest].PointsTo |= *(GraphNodes[*Src].PointsTo)) + NextWL->insert(&GraphNodes[*Dest]); + + } + li++; + } + } + SparseBitVector<> NewEdges; + SparseBitVector<> ToErase; + + // Now all we have left to do is propagate points-to info along the + // edges, erasing the redundant edges. + for (SparseBitVector<>::iterator bi = CurrNode->Edges->begin(); + bi != CurrNode->Edges->end(); + ++bi) { + + unsigned DestVar = *bi; + unsigned Rep = FindNode(DestVar); + + // If we ended up with this node as our destination, or we've already + // got an edge for the representative, delete the current edge. + if (Rep == CurrNodeIndex || + (Rep != DestVar && NewEdges.test(Rep))) { + ToErase.set(DestVar); + continue; + } + + std::pair<unsigned,unsigned> edge(CurrNodeIndex,Rep); + + // This is where we do lazy cycle detection. + // If this is a cycle candidate (equal points-to sets and this + // particular edge has not been cycle-checked previously), add to the + // list to check for cycles on the next iteration. 
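// Resolving a complex constraint means walking the delta of the points-to set
// it watches: for a load d = *s every new member m of pts(s) yields a copy
// edge m+K -> d, and for a store *d = s every new member m of pts(d) yields a
// copy edge s -> m+K.  A sketch of that expansion over plain containers; the
// MaxK and special-node checks of the real loop are omitted, and the names
// are invented for illustration:

#include <set>
#include <utility>
#include <vector>

enum class CxKind { Load, Store };
struct ComplexConstraint { CxKind Type; unsigned Dest, Src, Offset; };

// Returns the copy edges induced by C given the new members Delta of the
// points-to set it watches; the solver would then union points-to sets along
// each returned edge and requeue the destinations that grew.
std::vector<std::pair<unsigned, unsigned>>          // (from, to) copy edges
expandComplex(const ComplexConstraint &C, const std::set<unsigned> &Delta) {
  std::vector<std::pair<unsigned, unsigned>> Edges;
  for (unsigned M : Delta) {
    unsigned Member = M + C.Offset;                 // apply the field offset
    if (C.Type == CxKind::Load)
      Edges.push_back({Member, C.Dest});            // *s flows into d
    else
      Edges.push_back({C.Src, Member});             // s flows into *d
  }
  return Edges;
}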
+ if (!EdgesChecked.count(edge) && + *(GraphNodes[Rep].PointsTo) == *(CurrNode->PointsTo)) { + EdgesChecked.insert(edge); + TarjanWL.push(Rep); + } + // Union the points-to sets into the dest +#if !FULL_UNIVERSAL + if (Rep >= NumberSpecialNodes) +#endif + if (GraphNodes[Rep].PointsTo |= CurrPointsTo) { + NextWL->insert(&GraphNodes[Rep]); + } + // If this edge's destination was collapsed, rewrite the edge. + if (Rep != DestVar) { + ToErase.set(DestVar); + NewEdges.set(Rep); + } + } + CurrNode->Edges->intersectWithComplement(ToErase); + CurrNode->Edges |= NewEdges; + } + + // Switch to other work list. + WorkList* t = CurrWL; CurrWL = NextWL; NextWL = t; + } + + + Node2DFS.clear(); + Node2Deleted.clear(); + for (unsigned i = 0; i < GraphNodes.size(); ++i) { + Node *N = &GraphNodes[i]; + delete N->OldPointsTo; + delete N->Edges; + } + SDTActive = false; + SDT.clear(); +} + +//===----------------------------------------------------------------------===// +// Union-Find +//===----------------------------------------------------------------------===// + +// Unite nodes First and Second, returning the one which is now the +// representative node. First and Second are indexes into GraphNodes +unsigned Andersens::UniteNodes(unsigned First, unsigned Second, + bool UnionByRank) { + assert (First < GraphNodes.size() && Second < GraphNodes.size() && + "Attempting to merge nodes that don't exist"); + + Node *FirstNode = &GraphNodes[First]; + Node *SecondNode = &GraphNodes[Second]; + + assert (SecondNode->isRep() && FirstNode->isRep() && + "Trying to unite two non-representative nodes!"); + if (First == Second) + return First; + + if (UnionByRank) { + int RankFirst = (int) FirstNode ->NodeRep; + int RankSecond = (int) SecondNode->NodeRep; + + // Rank starts at -1 and gets decremented as it increases. + // Translation: higher rank, lower NodeRep value, which is always negative. + if (RankFirst > RankSecond) { + unsigned t = First; First = Second; Second = t; + Node* tp = FirstNode; FirstNode = SecondNode; SecondNode = tp; + } else if (RankFirst == RankSecond) { + FirstNode->NodeRep = (unsigned) (RankFirst - 1); + } + } + + SecondNode->NodeRep = First; +#if !FULL_UNIVERSAL + if (First >= NumberSpecialNodes) +#endif + if (FirstNode->PointsTo && SecondNode->PointsTo) + FirstNode->PointsTo |= *(SecondNode->PointsTo); + if (FirstNode->Edges && SecondNode->Edges) + FirstNode->Edges |= *(SecondNode->Edges); + if (!SecondNode->Constraints.empty()) + FirstNode->Constraints.splice(FirstNode->Constraints.begin(), + SecondNode->Constraints); + if (FirstNode->OldPointsTo) { + delete FirstNode->OldPointsTo; + FirstNode->OldPointsTo = new SparseBitVector<>; + } + + // Destroy interesting parts of the merged-from node. 
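// UniteNodes and FindNode are a textbook union-find: representatives carry a
// rank (encoded as a negative NodeRep above), everyone else stores a parent
// index, find compresses paths, and unite hangs the lower-rank tree under the
// higher-rank one unless the caller pins the direction, as FindEquivalentNode
// does.  A plain version with an explicit rank array instead of the NodeRep
// encoding:

#include <numeric>
#include <utility>
#include <vector>

struct UnionFind {
  std::vector<unsigned> Parent;
  std::vector<unsigned> Rank;

  explicit UnionFind(unsigned N) : Parent(N), Rank(N, 0) {
    std::iota(Parent.begin(), Parent.end(), 0u);   // everyone starts as its own rep
  }

  unsigned find(unsigned X) {
    if (Parent[X] != X)
      Parent[X] = find(Parent[X]);                 // path compression
    return Parent[X];
  }

  // Returns the surviving representative.
  unsigned unite(unsigned A, unsigned B, bool UnionByRank = true) {
    A = find(A); B = find(B);
    if (A == B) return A;
    if (UnionByRank) {
      if (Rank[A] < Rank[B]) std::swap(A, B);      // keep the taller tree as rep
      if (Rank[A] == Rank[B]) ++Rank[A];
    }
    Parent[B] = A;                                 // B's tree joins A
    return A;
  }
};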
+ delete SecondNode->OldPointsTo; + delete SecondNode->Edges; + delete SecondNode->PointsTo; + SecondNode->Edges = NULL; + SecondNode->PointsTo = NULL; + SecondNode->OldPointsTo = NULL; + + NumUnified++; + DOUT << "Unified Node "; + DEBUG(PrintNode(FirstNode)); + DOUT << " and Node "; + DEBUG(PrintNode(SecondNode)); + DOUT << "\n"; + + if (SDTActive) + if (SDT[Second] >= 0) { + if (SDT[First] < 0) + SDT[First] = SDT[Second]; + else { + UniteNodes( FindNode(SDT[First]), FindNode(SDT[Second]) ); + First = FindNode(First); + } + } + + return First; +} + +// Find the index into GraphNodes of the node representing Node, performing +// path compression along the way +unsigned Andersens::FindNode(unsigned NodeIndex) { + assert (NodeIndex < GraphNodes.size() + && "Attempting to find a node that can't exist"); + Node *N = &GraphNodes[NodeIndex]; + if (N->isRep()) + return NodeIndex; + else + return (N->NodeRep = FindNode(N->NodeRep)); +} + +// Find the index into GraphNodes of the node representing Node, +// don't perform path compression along the way (for Print) +unsigned Andersens::FindNode(unsigned NodeIndex) const { + assert (NodeIndex < GraphNodes.size() + && "Attempting to find a node that can't exist"); + const Node *N = &GraphNodes[NodeIndex]; + if (N->isRep()) + return NodeIndex; + else + return FindNode(N->NodeRep); +} + +//===----------------------------------------------------------------------===// +// Debugging Output +//===----------------------------------------------------------------------===// + +void Andersens::PrintNode(const Node *N) const { + if (N == &GraphNodes[UniversalSet]) { + cerr << "<universal>"; + return; + } else if (N == &GraphNodes[NullPtr]) { + cerr << "<nullptr>"; + return; + } else if (N == &GraphNodes[NullObject]) { + cerr << "<null>"; + return; + } + if (!N->getValue()) { + cerr << "artificial" << (intptr_t) N; + return; + } + + assert(N->getValue() != 0 && "Never set node label!"); + Value *V = N->getValue(); + if (Function *F = dyn_cast<Function>(V)) { + if (isa<PointerType>(F->getFunctionType()->getReturnType()) && + N == &GraphNodes[getReturnNode(F)]) { + cerr << F->getName() << ":retval"; + return; + } else if (F->getFunctionType()->isVarArg() && + N == &GraphNodes[getVarargNode(F)]) { + cerr << F->getName() << ":vararg"; + return; + } + } + + if (Instruction *I = dyn_cast<Instruction>(V)) + cerr << I->getParent()->getParent()->getName() << ":"; + else if (Argument *Arg = dyn_cast<Argument>(V)) + cerr << Arg->getParent()->getName() << ":"; + + if (V->hasName()) + cerr << V->getName(); + else + cerr << "(unnamed)"; + + if (isa<GlobalValue>(V) || isa<AllocationInst>(V)) + if (N == &GraphNodes[getObject(V)]) + cerr << "<mem>"; +} +void Andersens::PrintConstraint(const Constraint &C) const { + if (C.Type == Constraint::Store) { + cerr << "*"; + if (C.Offset != 0) + cerr << "("; + } + PrintNode(&GraphNodes[C.Dest]); + if (C.Type == Constraint::Store && C.Offset != 0) + cerr << " + " << C.Offset << ")"; + cerr << " = "; + if (C.Type == Constraint::Load) { + cerr << "*"; + if (C.Offset != 0) + cerr << "("; + } + else if (C.Type == Constraint::AddressOf) + cerr << "&"; + PrintNode(&GraphNodes[C.Src]); + if (C.Offset != 0 && C.Type != Constraint::Store) + cerr << " + " << C.Offset; + if (C.Type == Constraint::Load && C.Offset != 0) + cerr << ")"; + cerr << "\n"; +} + +void Andersens::PrintConstraints() const { + cerr << "Constraints:\n"; + + for (unsigned i = 0, e = Constraints.size(); i != e; ++i) + PrintConstraint(Constraints[i]); +} + +void 
Andersens::PrintPointsToGraph() const { + cerr << "Points-to graph:\n"; + for (unsigned i = 0, e = GraphNodes.size(); i != e; ++i) { + const Node *N = &GraphNodes[i]; + if (FindNode(i) != i) { + PrintNode(N); + cerr << "\t--> same as "; + PrintNode(&GraphNodes[FindNode(i)]); + cerr << "\n"; + } else { + cerr << "[" << (N->PointsTo->count()) << "] "; + PrintNode(N); + cerr << "\t--> "; + + bool first = true; + for (SparseBitVector<>::iterator bi = N->PointsTo->begin(); + bi != N->PointsTo->end(); + ++bi) { + if (!first) + cerr << ", "; + PrintNode(&GraphNodes[*bi]); + first = false; + } + cerr << "\n"; + } + } +} diff --git a/lib/Analysis/IPA/CMakeLists.txt b/lib/Analysis/IPA/CMakeLists.txt new file mode 100644 index 0000000..1ebb0be --- /dev/null +++ b/lib/Analysis/IPA/CMakeLists.txt @@ -0,0 +1,7 @@ +add_llvm_library(LLVMipa + Andersens.cpp + CallGraph.cpp + CallGraphSCCPass.cpp + FindUsedTypes.cpp + GlobalsModRef.cpp + ) diff --git a/lib/Analysis/IPA/CallGraph.cpp b/lib/Analysis/IPA/CallGraph.cpp new file mode 100644 index 0000000..6dabcdb --- /dev/null +++ b/lib/Analysis/IPA/CallGraph.cpp @@ -0,0 +1,314 @@ +//===- CallGraph.cpp - Build a Module's call graph ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the CallGraph class and provides the BasicCallGraph +// default implementation. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Module.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Support/CallSite.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Streams.h" +#include <ostream> +using namespace llvm; + +namespace { + +//===----------------------------------------------------------------------===// +// BasicCallGraph class definition +// +class VISIBILITY_HIDDEN BasicCallGraph : public CallGraph, public ModulePass { + // Root is root of the call graph, or the external node if a 'main' function + // couldn't be found. + // + CallGraphNode *Root; + + // ExternalCallingNode - This node has edges to all external functions and + // those internal functions that have their address taken. + CallGraphNode *ExternalCallingNode; + + // CallsExternalNode - This node has edges to it from all functions making + // indirect calls or calling an external function. + CallGraphNode *CallsExternalNode; + +public: + static char ID; // Class identification, replacement for typeinfo + BasicCallGraph() : ModulePass(&ID), Root(0), + ExternalCallingNode(0), CallsExternalNode(0) {} + + // runOnModule - Compute the call graph for the specified module. + virtual bool runOnModule(Module &M) { + CallGraph::initialize(M); + + ExternalCallingNode = getOrInsertFunction(0); + CallsExternalNode = new CallGraphNode(0); + Root = 0; + + // Add every function to the call graph... 
+ for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) + addToCallGraph(I); + + // If we didn't find a main function, use the external call graph node + if (Root == 0) Root = ExternalCallingNode; + + return false; + } + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + + void print(std::ostream *o, const Module *M) const { + if (o) print(*o, M); + } + + virtual void print(std::ostream &o, const Module *M) const { + o << "CallGraph Root is: "; + if (Function *F = getRoot()->getFunction()) + o << F->getName() << "\n"; + else + o << "<<null function: 0x" << getRoot() << ">>\n"; + + CallGraph::print(o, M); + } + + virtual void releaseMemory() { + destroy(); + } + + /// dump - Print out this call graph. + /// + inline void dump() const { + print(cerr, Mod); + } + + CallGraphNode* getExternalCallingNode() const { return ExternalCallingNode; } + CallGraphNode* getCallsExternalNode() const { return CallsExternalNode; } + + // getRoot - Return the root of the call graph, which is either main, or if + // main cannot be found, the external node. + // + CallGraphNode *getRoot() { return Root; } + const CallGraphNode *getRoot() const { return Root; } + +private: + //===--------------------------------------------------------------------- + // Implementation of CallGraph construction + // + + // addToCallGraph - Add a function to the call graph, and link the node to all + // of the functions that it calls. + // + void addToCallGraph(Function *F) { + CallGraphNode *Node = getOrInsertFunction(F); + + // If this function has external linkage, anything could call it. + if (!F->hasLocalLinkage()) { + ExternalCallingNode->addCalledFunction(CallSite(), Node); + + // Found the entry point? + if (F->getName() == "main") { + if (Root) // Found multiple external mains? Don't pick one. + Root = ExternalCallingNode; + else + Root = Node; // Found a main, keep track of it! + } + } + + // Loop over all of the users of the function, looking for non-call uses. + for (Value::use_iterator I = F->use_begin(), E = F->use_end(); I != E; ++I) + if ((!isa<CallInst>(I) && !isa<InvokeInst>(I)) + || !CallSite(cast<Instruction>(I)).isCallee(I)) { + // Not a call, or being used as a parameter rather than as the callee. + ExternalCallingNode->addCalledFunction(CallSite(), Node); + break; + } + + // If this function is not defined in this translation unit, it could call + // anything. + if (F->isDeclaration() && !F->isIntrinsic()) + Node->addCalledFunction(CallSite(), CallsExternalNode); + + // Look for calls by this function. + for (Function::iterator BB = F->begin(), BBE = F->end(); BB != BBE; ++BB) + for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); + II != IE; ++II) { + CallSite CS = CallSite::get(II); + if (CS.getInstruction() && !isa<DbgInfoIntrinsic>(II)) { + const Function *Callee = CS.getCalledFunction(); + if (Callee) + Node->addCalledFunction(CS, getOrInsertFunction(Callee)); + else + Node->addCalledFunction(CS, CallsExternalNode); + } + } + } + + // + // destroy - Release memory for the call graph + virtual void destroy() { + /// CallsExternalNode is not in the function map, delete it explicitly. 
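// addToCallGraph answers two questions per function: who may call it
// (anything at all, via an edge from the external-calling node, if it is
// externally visible or has its address taken) and whom it may call
// (everything, via the calls-external node, if it is only a declaration or
// makes indirect calls; otherwise exactly the call sites found in its body).
// A toy model of those rules; ToyFunction and the sentinel string keys are
// invented here, and the intrinsic and "main" root bookkeeping is omitted:

#include <map>
#include <set>
#include <string>
#include <vector>

struct ToyFunction {
  std::string Name;
  bool HasLocalLinkage;                    // static/internal linkage
  bool AddressTaken;                       // used other than as a callee
  bool IsDeclaration;                      // no body in this module
  bool HasIndirectCalls;                   // call sites with unknown callee
  std::vector<std::string> DirectCallees;  // resolved call sites in the body
};

using ToyCallGraph = std::map<std::string, std::set<std::string>>;

void addToCallGraph(const ToyFunction &F, ToyCallGraph &CG) {
  if (!F.HasLocalLinkage || F.AddressTaken)
    CG["<external-calling>"].insert(F.Name);   // unknown external callers
  if (F.IsDeclaration || F.HasIndirectCalls)
    CG[F.Name].insert("<calls-external>");     // may call anything
  for (const std::string &Callee : F.DirectCallees)
    CG[F.Name].insert(Callee);                 // one edge per explicit call site
}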
+ delete CallsExternalNode; + CallsExternalNode = 0; + CallGraph::destroy(); + } +}; + +} //End anonymous namespace + +static RegisterAnalysisGroup<CallGraph> X("Call Graph"); +static RegisterPass<BasicCallGraph> +Y("basiccg", "Basic CallGraph Construction", false, true); +static RegisterAnalysisGroup<CallGraph, true> Z(Y); + +char CallGraph::ID = 0; +char BasicCallGraph::ID = 0; + +void CallGraph::initialize(Module &M) { + Mod = &M; +} + +void CallGraph::destroy() { + if (!FunctionMap.empty()) { + for (FunctionMapTy::iterator I = FunctionMap.begin(), E = FunctionMap.end(); + I != E; ++I) + delete I->second; + FunctionMap.clear(); + } +} + +void CallGraph::print(std::ostream &OS, const Module *M) const { + for (CallGraph::const_iterator I = begin(), E = end(); I != E; ++I) + I->second->print(OS); +} + +void CallGraph::dump() const { + print(cerr, 0); +} + +//===----------------------------------------------------------------------===// +// Implementations of public modification methods +// + +// removeFunctionFromModule - Unlink the function from this module, returning +// it. Because this removes the function from the module, the call graph node +// is destroyed. This is only valid if the function does not call any other +// functions (ie, there are no edges in it's CGN). The easiest way to do this +// is to dropAllReferences before calling this. +// +Function *CallGraph::removeFunctionFromModule(CallGraphNode *CGN) { + assert(CGN->CalledFunctions.empty() && "Cannot remove function from call " + "graph if it references other functions!"); + Function *F = CGN->getFunction(); // Get the function for the call graph node + delete CGN; // Delete the call graph node for this func + FunctionMap.erase(F); // Remove the call graph node from the map + + Mod->getFunctionList().remove(F); + return F; +} + +// changeFunction - This method changes the function associated with this +// CallGraphNode, for use by transformations that need to change the prototype +// of a Function (thus they must create a new Function and move the old code +// over). +void CallGraph::changeFunction(Function *OldF, Function *NewF) { + iterator I = FunctionMap.find(OldF); + CallGraphNode *&New = FunctionMap[NewF]; + assert(I != FunctionMap.end() && I->second && !New && + "OldF didn't exist in CG or NewF already does!"); + New = I->second; + New->F = NewF; + FunctionMap.erase(I); +} + +// getOrInsertFunction - This method is identical to calling operator[], but +// it will insert a new CallGraphNode for the specified function if one does +// not already exist. +CallGraphNode *CallGraph::getOrInsertFunction(const Function *F) { + CallGraphNode *&CGN = FunctionMap[F]; + if (CGN) return CGN; + + assert((!F || F->getParent() == Mod) && "Function not in current module!"); + return CGN = new CallGraphNode(const_cast<Function*>(F)); +} + +void CallGraphNode::print(std::ostream &OS) const { + if (Function *F = getFunction()) + OS << "Call graph node for function: '" << F->getName() <<"'\n"; + else + OS << "Call graph node <<null function: 0x" << this << ">>:\n"; + + for (const_iterator I = begin(), E = end(); I != E; ++I) + if (Function *FI = I->second->getFunction()) + OS << " Calls function '" << FI->getName() <<"'\n"; + else + OS << " Calls external node\n"; + OS << "\n"; +} + +void CallGraphNode::dump() const { print(cerr); } + +/// removeCallEdgeFor - This method removes the edge in the node for the +/// specified call site. Note that this method takes linear time, so it +/// should be used sparingly. 
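// getOrInsertFunction leans on the map's operator[] creating an empty slot on
// first lookup and filling it exactly once, so every later query returns the
// same node.  The same pattern with std::map and unique_ptr; ToyCGNode is an
// invented stand-in for CallGraphNode:

#include <map>
#include <memory>
#include <string>

struct ToyCGNode { std::string Name; };

ToyCGNode *getOrInsert(std::map<std::string, std::unique_ptr<ToyCGNode>> &M,
                       const std::string &F) {
  std::unique_ptr<ToyCGNode> &Slot = M[F];   // creates an empty slot if new
  if (!Slot)
    Slot.reset(new ToyCGNode{F});            // fill it exactly once
  return Slot.get();
}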
+void CallGraphNode::removeCallEdgeFor(CallSite CS) { + for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) { + assert(I != CalledFunctions.end() && "Cannot find callsite to remove!"); + if (I->first == CS) { + CalledFunctions.erase(I); + return; + } + } +} + + +// removeAnyCallEdgeTo - This method removes any call edges from this node to +// the specified callee function. This takes more time to execute than +// removeCallEdgeTo, so it should not be used unless necessary. +void CallGraphNode::removeAnyCallEdgeTo(CallGraphNode *Callee) { + for (unsigned i = 0, e = CalledFunctions.size(); i != e; ++i) + if (CalledFunctions[i].second == Callee) { + CalledFunctions[i] = CalledFunctions.back(); + CalledFunctions.pop_back(); + --i; --e; + } +} + +/// removeOneAbstractEdgeTo - Remove one edge associated with a null callsite +/// from this node to the specified callee function. +void CallGraphNode::removeOneAbstractEdgeTo(CallGraphNode *Callee) { + for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) { + assert(I != CalledFunctions.end() && "Cannot find callee to remove!"); + CallRecord &CR = *I; + if (CR.second == Callee && !CR.first.getInstruction()) { + CalledFunctions.erase(I); + return; + } + } +} + +/// replaceCallSite - Make the edge in the node for Old CallSite be for +/// New CallSite instead. Note that this method takes linear time, so it +/// should be used sparingly. +void CallGraphNode::replaceCallSite(CallSite Old, CallSite New) { + for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) { + assert(I != CalledFunctions.end() && "Cannot find callsite to replace!"); + if (I->first == Old) { + I->first = New; + return; + } + } +} + +// Enuse that users of CallGraph.h also link with this file +DEFINING_FILE_FOR(CallGraph) diff --git a/lib/Analysis/IPA/CallGraphSCCPass.cpp b/lib/Analysis/IPA/CallGraphSCCPass.cpp new file mode 100644 index 0000000..3880d0a --- /dev/null +++ b/lib/Analysis/IPA/CallGraphSCCPass.cpp @@ -0,0 +1,207 @@ +//===- CallGraphSCCPass.cpp - Pass that operates BU on call graph ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the CallGraphSCCPass class, which is used for passes +// which are implemented as bottom-up traversals on the call graph. Because +// there may be cycles in the call graph, passes of this type operate on the +// call-graph in SCC order: that is, they process function bottom-up, except for +// recursive functions, which they process all at once. +// +//===----------------------------------------------------------------------===// + +#include "llvm/CallGraphSCCPass.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/PassManagers.h" +#include "llvm/Function.h" +using namespace llvm; + +//===----------------------------------------------------------------------===// +// CGPassManager +// +/// CGPassManager manages FPPassManagers and CalLGraphSCCPasses. + +namespace { + +class CGPassManager : public ModulePass, public PMDataManager { + +public: + static char ID; + explicit CGPassManager(int Depth) + : ModulePass(&ID), PMDataManager(Depth) { } + + /// run - Execute all of the passes scheduled for execution. Keep track of + /// whether any of the passes modifies the module, and if so, return true. 
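// removeAnyCallEdgeTo deletes from the middle of a vector without preserving
// order: overwrite the matching element with the last one, pop the back, and
// step the indices back so the moved-in element is re-examined.  The idiom in
// isolation:

#include <vector>

// Removes every occurrence of Value from V in O(n), order not preserved.
void eraseAllUnordered(std::vector<int> &V, int Value) {
  for (unsigned i = 0, e = V.size(); i != e; ++i)
    if (V[i] == Value) {
      V[i] = V.back();   // overwrite with the last element
      V.pop_back();
      --i; --e;          // re-check the element we just moved in
    }
}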
+ bool runOnModule(Module &M); + + bool doInitialization(CallGraph &CG); + bool doFinalization(CallGraph &CG); + + /// Pass Manager itself does not invalidate any analysis info. + void getAnalysisUsage(AnalysisUsage &Info) const { + // CGPassManager walks SCC and it needs CallGraph. + Info.addRequired<CallGraph>(); + Info.setPreservesAll(); + } + + virtual const char *getPassName() const { + return "CallGraph Pass Manager"; + } + + // Print passes managed by this manager + void dumpPassStructure(unsigned Offset) { + llvm::cerr << std::string(Offset*2, ' ') << "Call Graph SCC Pass Manager\n"; + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + Pass *P = getContainedPass(Index); + P->dumpPassStructure(Offset + 1); + dumpLastUses(P, Offset+1); + } + } + + Pass *getContainedPass(unsigned N) { + assert ( N < PassVector.size() && "Pass number out of range!"); + Pass *FP = static_cast<Pass *>(PassVector[N]); + return FP; + } + + virtual PassManagerType getPassManagerType() const { + return PMT_CallGraphPassManager; + } +}; + +} + +char CGPassManager::ID = 0; +/// run - Execute all of the passes scheduled for execution. Keep track of +/// whether any of the passes modifies the module, and if so, return true. +bool CGPassManager::runOnModule(Module &M) { + CallGraph &CG = getAnalysis<CallGraph>(); + bool Changed = doInitialization(CG); + + // Walk SCC + for (scc_iterator<CallGraph*> I = scc_begin(&CG), E = scc_end(&CG); + I != E; ++I) { + + // Run all passes on current SCC + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + Pass *P = getContainedPass(Index); + + dumpPassInfo(P, EXECUTION_MSG, ON_CG_MSG, ""); + dumpRequiredSet(P); + + initializeAnalysisImpl(P); + + StartPassTimer(P); + if (CallGraphSCCPass *CGSP = dynamic_cast<CallGraphSCCPass *>(P)) + Changed |= CGSP->runOnSCC(*I); // TODO : What if CG is changed ? 
+ else { + FPPassManager *FPP = dynamic_cast<FPPassManager *>(P); + assert (FPP && "Invalid CGPassManager member"); + + // Run pass P on all functions current SCC + std::vector<CallGraphNode*> &SCC = *I; + for (unsigned i = 0, e = SCC.size(); i != e; ++i) { + Function *F = SCC[i]->getFunction(); + if (F) { + dumpPassInfo(P, EXECUTION_MSG, ON_FUNCTION_MSG, F->getNameStart()); + Changed |= FPP->runOnFunction(*F); + } + } + } + StopPassTimer(P); + + if (Changed) + dumpPassInfo(P, MODIFICATION_MSG, ON_CG_MSG, ""); + dumpPreservedSet(P); + + verifyPreservedAnalysis(P); + removeNotPreservedAnalysis(P); + recordAvailableAnalysis(P); + removeDeadPasses(P, "", ON_CG_MSG); + } + } + Changed |= doFinalization(CG); + return Changed; +} + +/// Initialize CG +bool CGPassManager::doInitialization(CallGraph &CG) { + bool Changed = false; + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + Pass *P = getContainedPass(Index); + if (CallGraphSCCPass *CGSP = dynamic_cast<CallGraphSCCPass *>(P)) { + Changed |= CGSP->doInitialization(CG); + } else { + FPPassManager *FP = dynamic_cast<FPPassManager *>(P); + assert (FP && "Invalid CGPassManager member"); + Changed |= FP->doInitialization(CG.getModule()); + } + } + return Changed; +} + +/// Finalize CG +bool CGPassManager::doFinalization(CallGraph &CG) { + bool Changed = false; + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + Pass *P = getContainedPass(Index); + if (CallGraphSCCPass *CGSP = dynamic_cast<CallGraphSCCPass *>(P)) { + Changed |= CGSP->doFinalization(CG); + } else { + FPPassManager *FP = dynamic_cast<FPPassManager *>(P); + assert (FP && "Invalid CGPassManager member"); + Changed |= FP->doFinalization(CG.getModule()); + } + } + return Changed; +} + +/// Assign pass manager to manage this pass. +void CallGraphSCCPass::assignPassManager(PMStack &PMS, + PassManagerType PreferredType) { + // Find CGPassManager + while (!PMS.empty() && + PMS.top()->getPassManagerType() > PMT_CallGraphPassManager) + PMS.pop(); + + assert (!PMS.empty() && "Unable to handle Call Graph Pass"); + CGPassManager *CGP = dynamic_cast<CGPassManager *>(PMS.top()); + + // Create new Call Graph SCC Pass Manager if it does not exist. + if (!CGP) { + + assert (!PMS.empty() && "Unable to create Call Graph Pass Manager"); + PMDataManager *PMD = PMS.top(); + + // [1] Create new Call Graph Pass Manager + CGP = new CGPassManager(PMD->getDepth() + 1); + + // [2] Set up new manager's top level manager + PMTopLevelManager *TPM = PMD->getTopLevelManager(); + TPM->addIndirectPassManager(CGP); + + // [3] Assign manager to manage this new manager. This may create + // and push new managers into PMS + Pass *P = dynamic_cast<Pass *>(CGP); + TPM->schedulePass(P); + + // [4] Push new manager into PMS + PMS.push(CGP); + } + + CGP->add(this); +} + +/// getAnalysisUsage - For this class, we declare that we require and preserve +/// the call graph. If the derived class implements this method, it should +/// always explicitly call the implementation here. 
+void CallGraphSCCPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<CallGraph>(); + AU.addPreserved<CallGraph>(); +} diff --git a/lib/Analysis/IPA/FindUsedTypes.cpp b/lib/Analysis/IPA/FindUsedTypes.cpp new file mode 100644 index 0000000..920ee37 --- /dev/null +++ b/lib/Analysis/IPA/FindUsedTypes.cpp @@ -0,0 +1,104 @@ +//===- FindUsedTypes.cpp - Find all Types used by a module ----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass is used to seek out all of the types in use by the program. Note +// that this analysis explicitly does not include types only used by the symbol +// table. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/FindUsedTypes.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Module.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +char FindUsedTypes::ID = 0; +static RegisterPass<FindUsedTypes> +X("print-used-types", "Find Used Types", false, true); + +// IncorporateType - Incorporate one type and all of its subtypes into the +// collection of used types. +// +void FindUsedTypes::IncorporateType(const Type *Ty) { + // If ty doesn't already exist in the used types map, add it now, otherwise + // return. + if (!UsedTypes.insert(Ty).second) return; // Already contain Ty. + + // Make sure to add any types this type references now. + // + for (Type::subtype_iterator I = Ty->subtype_begin(), E = Ty->subtype_end(); + I != E; ++I) + IncorporateType(*I); +} + +void FindUsedTypes::IncorporateValue(const Value *V) { + IncorporateType(V->getType()); + + // If this is a constant, it could be using other types... + if (const Constant *C = dyn_cast<Constant>(V)) { + if (!isa<GlobalValue>(C)) + for (User::const_op_iterator OI = C->op_begin(), OE = C->op_end(); + OI != OE; ++OI) + IncorporateValue(*OI); + } +} + + +// run - This incorporates all types used by the specified module +// +bool FindUsedTypes::runOnModule(Module &m) { + UsedTypes.clear(); // reset if run multiple times... + + // Loop over global variables, incorporating their types + for (Module::const_global_iterator I = m.global_begin(), E = m.global_end(); + I != E; ++I) { + IncorporateType(I->getType()); + if (I->hasInitializer()) + IncorporateValue(I->getInitializer()); + } + + for (Module::iterator MI = m.begin(), ME = m.end(); MI != ME; ++MI) { + IncorporateType(MI->getType()); + const Function &F = *MI; + + // Loop over all of the instructions in the function, adding their return + // type as well as the types of their operands. + // + for (const_inst_iterator II = inst_begin(F), IE = inst_end(F); + II != IE; ++II) { + const Instruction &I = *II; + + IncorporateType(I.getType()); // Incorporate the type of the instruction + for (User::const_op_iterator OI = I.op_begin(), OE = I.op_end(); + OI != OE; ++OI) + IncorporateValue(*OI); // Insert inst operand types as well + } + } + + return false; +} + +// Print the types found in the module. If the optional Module parameter is +// passed in, then the types are printed symbolically if possible, using the +// symbol table from the module. 
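// IncorporateType is a depth-first closure with a visited set: insert the
// type, and only if the insertion was new recurse into the types it
// references, so shared and recursive types are still visited exactly once.
// The same pattern over a toy type node; ToyType is invented here:

#include <set>
#include <vector>

struct ToyType {
  const char *Name;
  std::vector<const ToyType *> Subtypes;   // element/field/parameter types
};

void incorporateType(const ToyType *Ty, std::set<const ToyType *> &Used) {
  if (!Used.insert(Ty).second)
    return;                                // already recorded, stop here
  for (const ToyType *Sub : Ty->Subtypes)
    incorporateType(Sub, Used);            // pull in everything it references
}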
+// +void FindUsedTypes::print(std::ostream &OS, const Module *M) const { + raw_os_ostream RO(OS); + RO << "Types in use by this module:\n"; + for (std::set<const Type *>::const_iterator I = UsedTypes.begin(), + E = UsedTypes.end(); I != E; ++I) { + RO << " "; + WriteTypeSymbolic(RO, *I, M); + RO << '\n'; + } +} diff --git a/lib/Analysis/IPA/GlobalsModRef.cpp b/lib/Analysis/IPA/GlobalsModRef.cpp new file mode 100644 index 0000000..2e9884a --- /dev/null +++ b/lib/Analysis/IPA/GlobalsModRef.cpp @@ -0,0 +1,567 @@ +//===- GlobalsModRef.cpp - Simple Mod/Ref Analysis for Globals ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This simple pass provides alias and mod/ref information for global values +// that do not have their address taken, and keeps track of whether functions +// read or write memory (are "pure"). For this simple (but very common) case, +// we can provide pretty accurate and useful information. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "globalsmodref-aa" +#include "llvm/Analysis/Passes.h" +#include "llvm/Module.h" +#include "llvm/Pass.h" +#include "llvm/Instructions.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SCCIterator.h" +#include <set> +using namespace llvm; + +STATISTIC(NumNonAddrTakenGlobalVars, + "Number of global vars without address taken"); +STATISTIC(NumNonAddrTakenFunctions,"Number of functions without address taken"); +STATISTIC(NumNoMemFunctions, "Number of functions that do not access memory"); +STATISTIC(NumReadMemFunctions, "Number of functions that only read memory"); +STATISTIC(NumIndirectGlobalVars, "Number of indirect global objects"); + +namespace { + /// FunctionRecord - One instance of this structure is stored for every + /// function in the program. Later, the entries for these functions are + /// removed if the function is found to call an external function (in which + /// case we know nothing about it. + struct VISIBILITY_HIDDEN FunctionRecord { + /// GlobalInfo - Maintain mod/ref info for all of the globals without + /// addresses taken that are read or written (transitively) by this + /// function. + std::map<GlobalValue*, unsigned> GlobalInfo; + + /// MayReadAnyGlobal - May read global variables, but it is not known which. + bool MayReadAnyGlobal; + + unsigned getInfoForGlobal(GlobalValue *GV) const { + unsigned Effect = MayReadAnyGlobal ? AliasAnalysis::Ref : 0; + std::map<GlobalValue*, unsigned>::const_iterator I = GlobalInfo.find(GV); + if (I != GlobalInfo.end()) + Effect |= I->second; + return Effect; + } + + /// FunctionEffect - Capture whether or not this function reads or writes to + /// ANY memory. If not, we can do a lot of aggressive analysis on it. + unsigned FunctionEffect; + + FunctionRecord() : MayReadAnyGlobal (false), FunctionEffect(0) {} + }; + + /// GlobalsModRef - The actual analysis pass. + class VISIBILITY_HIDDEN GlobalsModRef + : public ModulePass, public AliasAnalysis { + /// NonAddressTakenGlobals - The globals that do not have their addresses + /// taken. 
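// Per-function mod/ref facts are two bits per global, and MayReadAnyGlobal
// folds a blanket Ref bit into every lookup for functions that may call back
// into the module.  A self-contained model of getInfoForGlobal; the type and
// the bit values are illustrative, not the AliasAnalysis enum itself:

#include <map>
#include <string>

enum ModRefBits { NoModRefBit = 0, RefBit = 1, ModBit = 2 };

struct ToyFunctionRecord {
  std::map<std::string, unsigned> GlobalInfo;  // global name -> mod/ref bits
  bool MayReadAnyGlobal = false;

  unsigned getInfoForGlobal(const std::string &GV) const {
    unsigned Effect = MayReadAnyGlobal ? RefBit : NoModRefBit;
    auto It = GlobalInfo.find(GV);
    if (It != GlobalInfo.end())
      Effect |= It->second;                    // add the per-global bits
    return Effect;
  }
};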
+ std::set<GlobalValue*> NonAddressTakenGlobals; + + /// IndirectGlobals - The memory pointed to by this global is known to be + /// 'owned' by the global. + std::set<GlobalValue*> IndirectGlobals; + + /// AllocsForIndirectGlobals - If an instruction allocates memory for an + /// indirect global, this map indicates which one. + std::map<Value*, GlobalValue*> AllocsForIndirectGlobals; + + /// FunctionInfo - For each function, keep track of what globals are + /// modified or read. + std::map<Function*, FunctionRecord> FunctionInfo; + + public: + static char ID; + GlobalsModRef() : ModulePass(&ID) {} + + bool runOnModule(Module &M) { + InitializeAliasAnalysis(this); // set up super class + AnalyzeGlobals(M); // find non-addr taken globals + AnalyzeCallGraph(getAnalysis<CallGraph>(), M); // Propagate on CG + return false; + } + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AliasAnalysis::getAnalysisUsage(AU); + AU.addRequired<CallGraph>(); + AU.setPreservesAll(); // Does not transform code + } + + //------------------------------------------------ + // Implement the AliasAnalysis API + // + AliasResult alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size); + ModRefResult getModRefInfo(CallSite CS, Value *P, unsigned Size); + ModRefResult getModRefInfo(CallSite CS1, CallSite CS2) { + return AliasAnalysis::getModRefInfo(CS1,CS2); + } + bool hasNoModRefInfoForCalls() const { return false; } + + /// getModRefBehavior - Return the behavior of the specified function if + /// called from the specified call site. The call site may be null in which + /// case the most generic behavior of this function should be returned. + ModRefBehavior getModRefBehavior(Function *F, + std::vector<PointerAccessInfo> *Info) { + if (FunctionRecord *FR = getFunctionInfo(F)) { + if (FR->FunctionEffect == 0) + return DoesNotAccessMemory; + else if ((FR->FunctionEffect & Mod) == 0) + return OnlyReadsMemory; + } + return AliasAnalysis::getModRefBehavior(F, Info); + } + + /// getModRefBehavior - Return the behavior of the specified function if + /// called from the specified call site. The call site may be null in which + /// case the most generic behavior of this function should be returned. + ModRefBehavior getModRefBehavior(CallSite CS, + std::vector<PointerAccessInfo> *Info) { + Function* F = CS.getCalledFunction(); + if (!F) return AliasAnalysis::getModRefBehavior(CS, Info); + if (FunctionRecord *FR = getFunctionInfo(F)) { + if (FR->FunctionEffect == 0) + return DoesNotAccessMemory; + else if ((FR->FunctionEffect & Mod) == 0) + return OnlyReadsMemory; + } + return AliasAnalysis::getModRefBehavior(CS, Info); + } + + virtual void deleteValue(Value *V); + virtual void copyValue(Value *From, Value *To); + + private: + /// getFunctionInfo - Return the function info for the function, or null if + /// we don't have anything useful to say about it. 
+ FunctionRecord *getFunctionInfo(Function *F) { + std::map<Function*, FunctionRecord>::iterator I = FunctionInfo.find(F); + if (I != FunctionInfo.end()) + return &I->second; + return 0; + } + + void AnalyzeGlobals(Module &M); + void AnalyzeCallGraph(CallGraph &CG, Module &M); + bool AnalyzeUsesOfPointer(Value *V, std::vector<Function*> &Readers, + std::vector<Function*> &Writers, + GlobalValue *OkayStoreDest = 0); + bool AnalyzeIndirectGlobalMemory(GlobalValue *GV); + }; +} + +char GlobalsModRef::ID = 0; +static RegisterPass<GlobalsModRef> +X("globalsmodref-aa", "Simple mod/ref analysis for globals", false, true); +static RegisterAnalysisGroup<AliasAnalysis> Y(X); + +Pass *llvm::createGlobalsModRefPass() { return new GlobalsModRef(); } + +/// AnalyzeGlobals - Scan through the users of all of the internal +/// GlobalValue's in the program. If none of them have their "address taken" +/// (really, their address passed to something nontrivial), record this fact, +/// and record the functions that they are used directly in. +void GlobalsModRef::AnalyzeGlobals(Module &M) { + std::vector<Function*> Readers, Writers; + for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) + if (I->hasLocalLinkage()) { + if (!AnalyzeUsesOfPointer(I, Readers, Writers)) { + // Remember that we are tracking this global. + NonAddressTakenGlobals.insert(I); + ++NumNonAddrTakenFunctions; + } + Readers.clear(); Writers.clear(); + } + + for (Module::global_iterator I = M.global_begin(), E = M.global_end(); + I != E; ++I) + if (I->hasLocalLinkage()) { + if (!AnalyzeUsesOfPointer(I, Readers, Writers)) { + // Remember that we are tracking this global, and the mod/ref fns + NonAddressTakenGlobals.insert(I); + + for (unsigned i = 0, e = Readers.size(); i != e; ++i) + FunctionInfo[Readers[i]].GlobalInfo[I] |= Ref; + + if (!I->isConstant()) // No need to keep track of writers to constants + for (unsigned i = 0, e = Writers.size(); i != e; ++i) + FunctionInfo[Writers[i]].GlobalInfo[I] |= Mod; + ++NumNonAddrTakenGlobalVars; + + // If this global holds a pointer type, see if it is an indirect global. + if (isa<PointerType>(I->getType()->getElementType()) && + AnalyzeIndirectGlobalMemory(I)) + ++NumIndirectGlobalVars; + } + Readers.clear(); Writers.clear(); + } +} + +/// AnalyzeUsesOfPointer - Look at all of the users of the specified pointer. +/// If this is used by anything complex (i.e., the address escapes), return +/// true. Also, while we are at it, keep track of those functions that read and +/// write to the value. +/// +/// If OkayStoreDest is non-null, stores into this global are allowed. +bool GlobalsModRef::AnalyzeUsesOfPointer(Value *V, + std::vector<Function*> &Readers, + std::vector<Function*> &Writers, + GlobalValue *OkayStoreDest) { + if (!isa<PointerType>(V->getType())) return true; + + for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E; ++UI) + if (LoadInst *LI = dyn_cast<LoadInst>(*UI)) { + Readers.push_back(LI->getParent()->getParent()); + } else if (StoreInst *SI = dyn_cast<StoreInst>(*UI)) { + if (V == SI->getOperand(1)) { + Writers.push_back(SI->getParent()->getParent()); + } else if (SI->getOperand(1) != OkayStoreDest) { + return true; // Storing the pointer + } + } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(*UI)) { + if (AnalyzeUsesOfPointer(GEP, Readers, Writers)) return true; + } else if (CallInst *CI = dyn_cast<CallInst>(*UI)) { + // Make sure that this is just the function being called, not that it is + // passing into the function. 
+ for (unsigned i = 1, e = CI->getNumOperands(); i != e; ++i) + if (CI->getOperand(i) == V) return true; + } else if (InvokeInst *II = dyn_cast<InvokeInst>(*UI)) { + // Make sure that this is just the function being called, not that it is + // passing into the function. + for (unsigned i = 3, e = II->getNumOperands(); i != e; ++i) + if (II->getOperand(i) == V) return true; + } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(*UI)) { + if (CE->getOpcode() == Instruction::GetElementPtr || + CE->getOpcode() == Instruction::BitCast) { + if (AnalyzeUsesOfPointer(CE, Readers, Writers)) + return true; + } else { + return true; + } + } else if (ICmpInst *ICI = dyn_cast<ICmpInst>(*UI)) { + if (!isa<ConstantPointerNull>(ICI->getOperand(1))) + return true; // Allow comparison against null. + } else if (FreeInst *F = dyn_cast<FreeInst>(*UI)) { + Writers.push_back(F->getParent()->getParent()); + } else { + return true; + } + return false; +} + +/// AnalyzeIndirectGlobalMemory - We found an non-address-taken global variable +/// which holds a pointer type. See if the global always points to non-aliased +/// heap memory: that is, all initializers of the globals are allocations, and +/// those allocations have no use other than initialization of the global. +/// Further, all loads out of GV must directly use the memory, not store the +/// pointer somewhere. If this is true, we consider the memory pointed to by +/// GV to be owned by GV and can disambiguate other pointers from it. +bool GlobalsModRef::AnalyzeIndirectGlobalMemory(GlobalValue *GV) { + // Keep track of values related to the allocation of the memory, f.e. the + // value produced by the malloc call and any casts. + std::vector<Value*> AllocRelatedValues; + + // Walk the user list of the global. If we find anything other than a direct + // load or store, bail out. + for (Value::use_iterator I = GV->use_begin(), E = GV->use_end(); I != E; ++I){ + if (LoadInst *LI = dyn_cast<LoadInst>(*I)) { + // The pointer loaded from the global can only be used in simple ways: + // we allow addressing of it and loading storing to it. We do *not* allow + // storing the loaded pointer somewhere else or passing to a function. + std::vector<Function*> ReadersWriters; + if (AnalyzeUsesOfPointer(LI, ReadersWriters, ReadersWriters)) + return false; // Loaded pointer escapes. + // TODO: Could try some IP mod/ref of the loaded pointer. + } else if (StoreInst *SI = dyn_cast<StoreInst>(*I)) { + // Storing the global itself. + if (SI->getOperand(0) == GV) return false; + + // If storing the null pointer, ignore it. + if (isa<ConstantPointerNull>(SI->getOperand(0))) + continue; + + // Check the value being stored. + Value *Ptr = SI->getOperand(0)->getUnderlyingObject(); + + if (isa<MallocInst>(Ptr)) { + // Okay, easy case. + } else if (CallInst *CI = dyn_cast<CallInst>(Ptr)) { + Function *F = CI->getCalledFunction(); + if (!F || !F->isDeclaration()) return false; // Too hard to analyze. + if (F->getName() != "calloc") return false; // Not calloc. + } else { + return false; // Too hard to analyze. + } + + // Analyze all uses of the allocation. If any of them are used in a + // non-simple way (e.g. stored to another global) bail out. + std::vector<Function*> ReadersWriters; + if (AnalyzeUsesOfPointer(Ptr, ReadersWriters, ReadersWriters, GV)) + return false; // Loaded pointer escapes. + + // Remember that this allocation is related to the indirect global. + AllocRelatedValues.push_back(Ptr); + } else { + // Something complex, bail out. 
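      // Illustration: only direct loads and stores of GV are accepted by this
      // walk; anything else -- even uses AnalyzeUsesOfPointer tolerates, such as
      // a GEP into the global or a compare against null -- defeats the
      // "GV owns the memory it points to" model, so we give up here.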
+ return false; + } + } + + // Okay, this is an indirect global. Remember all of the allocations for + // this global in AllocsForIndirectGlobals. + while (!AllocRelatedValues.empty()) { + AllocsForIndirectGlobals[AllocRelatedValues.back()] = GV; + AllocRelatedValues.pop_back(); + } + IndirectGlobals.insert(GV); + return true; +} + +/// AnalyzeCallGraph - At this point, we know the functions where globals are +/// immediately stored to and read from. Propagate this information up the call +/// graph to all callers and compute the mod/ref info for all memory for each +/// function. +void GlobalsModRef::AnalyzeCallGraph(CallGraph &CG, Module &M) { + // We do a bottom-up SCC traversal of the call graph. In other words, we + // visit all callees before callers (leaf-first). + for (scc_iterator<CallGraph*> I = scc_begin(&CG), E = scc_end(&CG); I != E; + ++I) { + std::vector<CallGraphNode *> &SCC = *I; + assert(!SCC.empty() && "SCC with no functions?"); + + if (!SCC[0]->getFunction()) { + // Calls externally - can't say anything useful. Remove any existing + // function records (may have been created when scanning globals). + for (unsigned i = 0, e = SCC.size(); i != e; ++i) + FunctionInfo.erase(SCC[i]->getFunction()); + continue; + } + + FunctionRecord &FR = FunctionInfo[SCC[0]->getFunction()]; + + bool KnowNothing = false; + unsigned FunctionEffect = 0; + + // Collect the mod/ref properties due to called functions. We only compute + // one mod-ref set. + for (unsigned i = 0, e = SCC.size(); i != e && !KnowNothing; ++i) { + Function *F = SCC[i]->getFunction(); + if (!F) { + KnowNothing = true; + break; + } + + if (F->isDeclaration()) { + // Try to get mod/ref behaviour from function attributes. + if (F->doesNotAccessMemory()) { + // Can't do better than that! + } else if (F->onlyReadsMemory()) { + FunctionEffect |= Ref; + if (!F->isIntrinsic()) + // This function might call back into the module and read a global - + // consider every global as possibly being read by this function. + FR.MayReadAnyGlobal = true; + } else { + FunctionEffect |= ModRef; + // Can't say anything useful unless it's an intrinsic - they don't + // read or write global variables of the kind considered here. + KnowNothing = !F->isIntrinsic(); + } + continue; + } + + for (CallGraphNode::iterator CI = SCC[i]->begin(), E = SCC[i]->end(); + CI != E && !KnowNothing; ++CI) + if (Function *Callee = CI->second->getFunction()) { + if (FunctionRecord *CalleeFR = getFunctionInfo(Callee)) { + // Propagate function effect up. + FunctionEffect |= CalleeFR->FunctionEffect; + + // Incorporate callee's effects on globals into our info. + for (std::map<GlobalValue*, unsigned>::iterator GI = + CalleeFR->GlobalInfo.begin(), E = CalleeFR->GlobalInfo.end(); + GI != E; ++GI) + FR.GlobalInfo[GI->first] |= GI->second; + FR.MayReadAnyGlobal |= CalleeFR->MayReadAnyGlobal; + } else { + // Can't say anything about it. However, if it is inside our SCC, + // then nothing needs to be done. + CallGraphNode *CalleeNode = CG[Callee]; + if (std::find(SCC.begin(), SCC.end(), CalleeNode) == SCC.end()) + KnowNothing = true; + } + } else { + KnowNothing = true; + } + } + + // If we can't say anything useful about this SCC, remove all SCC functions + // from the FunctionInfo map. + if (KnowNothing) { + for (unsigned i = 0, e = SCC.size(); i != e; ++i) + FunctionInfo.erase(SCC[i]->getFunction()); + continue; + } + + // Scan the function bodies for explicit loads or stores. 
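    // Illustration: the callee records merged above were themselves computed
    // bottom-up, so FunctionEffect already reflects everything reachable
    // through calls (a callee writing a hypothetical @Counter has already
    // contributed Mod); this scan only needs to add the SCC's own explicit
    // memory operations.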
+ for (unsigned i = 0, e = SCC.size(); i != e && FunctionEffect != ModRef;++i) + for (inst_iterator II = inst_begin(SCC[i]->getFunction()), + E = inst_end(SCC[i]->getFunction()); + II != E && FunctionEffect != ModRef; ++II) + if (isa<LoadInst>(*II)) { + FunctionEffect |= Ref; + if (cast<LoadInst>(*II).isVolatile()) + // Volatile loads may have side-effects, so mark them as writing + // memory (for example, a flag inside the processor). + FunctionEffect |= Mod; + } else if (isa<StoreInst>(*II)) { + FunctionEffect |= Mod; + if (cast<StoreInst>(*II).isVolatile()) + // Treat volatile stores as reading memory somewhere. + FunctionEffect |= Ref; + } else if (isa<MallocInst>(*II) || isa<FreeInst>(*II)) { + FunctionEffect |= ModRef; + } + + if ((FunctionEffect & Mod) == 0) + ++NumReadMemFunctions; + if (FunctionEffect == 0) + ++NumNoMemFunctions; + FR.FunctionEffect = FunctionEffect; + + // Finally, now that we know the full effect on this SCC, clone the + // information to each function in the SCC. + for (unsigned i = 1, e = SCC.size(); i != e; ++i) + FunctionInfo[SCC[i]->getFunction()] = FR; + } +} + + + +/// alias - If one of the pointers is to a global that we are tracking, and the +/// other is some random pointer, we know there cannot be an alias, because the +/// address of the global isn't taken. +AliasAnalysis::AliasResult +GlobalsModRef::alias(const Value *V1, unsigned V1Size, + const Value *V2, unsigned V2Size) { + // Get the base object these pointers point to. + Value *UV1 = const_cast<Value*>(V1->getUnderlyingObject()); + Value *UV2 = const_cast<Value*>(V2->getUnderlyingObject()); + + // If either of the underlying values is a global, they may be non-addr-taken + // globals, which we can answer queries about. + GlobalValue *GV1 = dyn_cast<GlobalValue>(UV1); + GlobalValue *GV2 = dyn_cast<GlobalValue>(UV2); + if (GV1 || GV2) { + // If the global's address is taken, pretend we don't know it's a pointer to + // the global. + if (GV1 && !NonAddressTakenGlobals.count(GV1)) GV1 = 0; + if (GV2 && !NonAddressTakenGlobals.count(GV2)) GV2 = 0; + + // If the the two pointers are derived from two different non-addr-taken + // globals, or if one is and the other isn't, we know these can't alias. + if ((GV1 || GV2) && GV1 != GV2) + return NoAlias; + + // Otherwise if they are both derived from the same addr-taken global, we + // can't know the two accesses don't overlap. + } + + // These pointers may be based on the memory owned by an indirect global. If + // so, we may be able to handle this. First check to see if the base pointer + // is a direct load from an indirect global. + GV1 = GV2 = 0; + if (LoadInst *LI = dyn_cast<LoadInst>(UV1)) + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getOperand(0))) + if (IndirectGlobals.count(GV)) + GV1 = GV; + if (LoadInst *LI = dyn_cast<LoadInst>(UV2)) + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getOperand(0))) + if (IndirectGlobals.count(GV)) + GV2 = GV; + + // These pointers may also be from an allocation for the indirect global. If + // so, also handle them. + if (AllocsForIndirectGlobals.count(UV1)) + GV1 = AllocsForIndirectGlobals[UV1]; + if (AllocsForIndirectGlobals.count(UV2)) + GV2 = AllocsForIndirectGlobals[UV2]; + + // Now that we know whether the two pointers are related to indirect globals, + // use this to disambiguate the pointers. If either pointer is based on an + // indirect global and if they are not both based on the same indirect global, + // they cannot alias. 
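  // Illustration (hypothetical names): given two internal globals that are only
  // ever initialized with their own allocation,
  //   @C1 = internal global i32* null
  //   @C2 = internal global i32* null
  // pointers loaded from @C1 and @C2 reach this point with GV1 != GV2, so the
  // check below answers NoAlias without consulting any other analysis.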
+ if ((GV1 || GV2) && GV1 != GV2) + return NoAlias; + + return AliasAnalysis::alias(V1, V1Size, V2, V2Size); +} + +AliasAnalysis::ModRefResult +GlobalsModRef::getModRefInfo(CallSite CS, Value *P, unsigned Size) { + unsigned Known = ModRef; + + // If we are asking for mod/ref info of a direct call with a pointer to a + // global we are tracking, return information if we have it. + if (GlobalValue *GV = dyn_cast<GlobalValue>(P->getUnderlyingObject())) + if (GV->hasLocalLinkage()) + if (Function *F = CS.getCalledFunction()) + if (NonAddressTakenGlobals.count(GV)) + if (FunctionRecord *FR = getFunctionInfo(F)) + Known = FR->getInfoForGlobal(GV); + + if (Known == NoModRef) + return NoModRef; // No need to query other mod/ref analyses + return ModRefResult(Known & AliasAnalysis::getModRefInfo(CS, P, Size)); +} + + +//===----------------------------------------------------------------------===// +// Methods to update the analysis as a result of the client transformation. +// +void GlobalsModRef::deleteValue(Value *V) { + if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) { + if (NonAddressTakenGlobals.erase(GV)) { + // This global might be an indirect global. If so, remove it and remove + // any AllocRelatedValues for it. + if (IndirectGlobals.erase(GV)) { + // Remove any entries in AllocsForIndirectGlobals for this global. + for (std::map<Value*, GlobalValue*>::iterator + I = AllocsForIndirectGlobals.begin(), + E = AllocsForIndirectGlobals.end(); I != E; ) { + if (I->second == GV) { + AllocsForIndirectGlobals.erase(I++); + } else { + ++I; + } + } + } + } + } + + // Otherwise, if this is an allocation related to an indirect global, remove + // it. + AllocsForIndirectGlobals.erase(V); + + AliasAnalysis::deleteValue(V); +} + +void GlobalsModRef::copyValue(Value *From, Value *To) { + AliasAnalysis::copyValue(From, To); +} diff --git a/lib/Analysis/IPA/Makefile b/lib/Analysis/IPA/Makefile new file mode 100644 index 0000000..adacb16 --- /dev/null +++ b/lib/Analysis/IPA/Makefile @@ -0,0 +1,14 @@ +##===- lib/Analysis/IPA/Makefile ---------------------------*- Makefile -*-===## +# +# The LLVM Compiler Infrastructure +# +# This file is distributed under the University of Illinois Open Source +# License. See LICENSE.TXT for details. +# +##===----------------------------------------------------------------------===## + +LEVEL = ../../.. +LIBRARYNAME = LLVMipa +BUILD_ARCHIVE = 1 +include $(LEVEL)/Makefile.common + diff --git a/lib/Analysis/IVUsers.cpp b/lib/Analysis/IVUsers.cpp new file mode 100644 index 0000000..7af9130 --- /dev/null +++ b/lib/Analysis/IVUsers.cpp @@ -0,0 +1,391 @@ +//===- IVUsers.cpp - Induction Variable Users -------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements bookkeeping for "interesting" users of expressions +// computed from induction variables. 
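// For illustration: in a loop governed by an induction variable %i, a user such
// as "%p = getelementptr i32* %base, i32 %i" is recorded here together with the
// start and stride of its SCEV, so that a later pass (loop strength reduction
// in particular) can reason about it and rewrite it.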
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "iv-users" +#include "llvm/Analysis/IVUsers.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Type.h" +#include "llvm/DerivedTypes.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +using namespace llvm; + +char IVUsers::ID = 0; +static RegisterPass<IVUsers> +X("iv-users", "Induction Variable Users", false, true); + +Pass *llvm::createIVUsersPass() { + return new IVUsers(); +} + +/// containsAddRecFromDifferentLoop - Determine whether expression S involves a +/// subexpression that is an AddRec from a loop other than L. An outer loop +/// of L is OK, but not an inner loop nor a disjoint loop. +static bool containsAddRecFromDifferentLoop(SCEVHandle S, Loop *L) { + // This is very common, put it first. + if (isa<SCEVConstant>(S)) + return false; + if (const SCEVCommutativeExpr *AE = dyn_cast<SCEVCommutativeExpr>(S)) { + for (unsigned int i=0; i< AE->getNumOperands(); i++) + if (containsAddRecFromDifferentLoop(AE->getOperand(i), L)) + return true; + return false; + } + if (const SCEVAddRecExpr *AE = dyn_cast<SCEVAddRecExpr>(S)) { + if (const Loop *newLoop = AE->getLoop()) { + if (newLoop == L) + return false; + // if newLoop is an outer loop of L, this is OK. + if (!LoopInfoBase<BasicBlock>::isNotAlreadyContainedIn(L, newLoop)) + return false; + } + return true; + } + if (const SCEVUDivExpr *DE = dyn_cast<SCEVUDivExpr>(S)) + return containsAddRecFromDifferentLoop(DE->getLHS(), L) || + containsAddRecFromDifferentLoop(DE->getRHS(), L); +#if 0 + // SCEVSDivExpr has been backed out temporarily, but will be back; we'll + // need this when it is. + if (const SCEVSDivExpr *DE = dyn_cast<SCEVSDivExpr>(S)) + return containsAddRecFromDifferentLoop(DE->getLHS(), L) || + containsAddRecFromDifferentLoop(DE->getRHS(), L); +#endif + if (const SCEVCastExpr *CE = dyn_cast<SCEVCastExpr>(S)) + return containsAddRecFromDifferentLoop(CE->getOperand(), L); + return false; +} + +/// getSCEVStartAndStride - Compute the start and stride of this expression, +/// returning false if the expression is not a start/stride pair, or true if it +/// is. The stride must be a loop invariant expression, but the start may be +/// a mix of loop invariant and loop variant expressions. The start cannot, +/// however, contain an AddRec from a different loop, unless that loop is an +/// outer loop of the current loop. +static bool getSCEVStartAndStride(const SCEVHandle &SH, Loop *L, Loop *UseLoop, + SCEVHandle &Start, SCEVHandle &Stride, + bool &isSigned, + ScalarEvolution *SE, DominatorTree *DT) { + SCEVHandle TheAddRec = Start; // Initialize to zero. + bool isSExt = false; + bool isZExt = false; + + // If the outer level is an AddExpr, the operands are all start values except + // for a nested AddRecExpr. + if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(SH)) { + for (unsigned i = 0, e = AE->getNumOperands(); i != e; ++i) + if (const SCEVAddRecExpr *AddRec = + dyn_cast<SCEVAddRecExpr>(AE->getOperand(i))) { + if (AddRec->getLoop() == L) + TheAddRec = SE->getAddExpr(AddRec, TheAddRec); + else + return false; // Nested IV of some sort? 
+ } else { + Start = SE->getAddExpr(Start, AE->getOperand(i)); + } + + } else if (const SCEVZeroExtendExpr *Z = dyn_cast<SCEVZeroExtendExpr>(SH)) { + TheAddRec = Z->getOperand(); + isZExt = true; + } else if (const SCEVSignExtendExpr *S = dyn_cast<SCEVSignExtendExpr>(SH)) { + TheAddRec = S->getOperand(); + isSExt = true; + } else if (isa<SCEVAddRecExpr>(SH)) { + TheAddRec = SH; + } else { + return false; // not analyzable. + } + + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(TheAddRec); + if (!AddRec || AddRec->getLoop() != L) return false; + + // Use getSCEVAtScope to attempt to simplify other loops out of + // the picture. + SCEVHandle AddRecStart = AddRec->getStart(); + SCEVHandle BetterAddRecStart = SE->getSCEVAtScope(AddRecStart, UseLoop); + if (!isa<SCEVCouldNotCompute>(BetterAddRecStart)) + AddRecStart = BetterAddRecStart; + + // FIXME: If Start contains an SCEVAddRecExpr from a different loop, other + // than an outer loop of the current loop, reject it. LSR has no concept of + // operating on more than one loop at a time so don't confuse it with such + // expressions. + if (containsAddRecFromDifferentLoop(AddRecStart, L)) + return false; + + if (isSExt || isZExt) + Start = SE->getTruncateExpr(Start, AddRec->getType()); + + Start = SE->getAddExpr(Start, AddRecStart); + + if (!isa<SCEVConstant>(AddRec->getStepRecurrence(*SE))) { + // If stride is an instruction, make sure it dominates the loop preheader. + // Otherwise we could end up with a use before def situation. + BasicBlock *Preheader = L->getLoopPreheader(); + if (!AddRec->getStepRecurrence(*SE)->dominates(Preheader, DT)) + return false; + + DOUT << "[" << L->getHeader()->getName() + << "] Variable stride: " << *AddRec << "\n"; + } + + Stride = AddRec->getStepRecurrence(*SE); + isSigned = isSExt; + return true; +} + +/// IVUseShouldUsePostIncValue - We have discovered a "User" of an IV expression +/// and now we need to decide whether the user should use the preinc or post-inc +/// value. If this user should use the post-inc version of the IV, return true. +/// +/// Choosing wrong here can break dominance properties (if we choose to use the +/// post-inc value when we cannot) or it can end up adding extra live-ranges to +/// the loop, resulting in reg-reg copies (if we use the pre-inc value when we +/// should use the post-inc value). +static bool IVUseShouldUsePostIncValue(Instruction *User, Instruction *IV, + Loop *L, LoopInfo *LI, DominatorTree *DT, + Pass *P) { + // If the user is in the loop, use the preinc value. + if (L->contains(User->getParent())) return false; + + BasicBlock *LatchBlock = L->getLoopLatch(); + + // Ok, the user is outside of the loop. If it is dominated by the latch + // block, use the post-inc value. + if (DT->dominates(LatchBlock, User->getParent())) + return true; + + // There is one case we have to be careful of: PHI nodes. These little guys + // can live in blocks that are not dominated by the latch block, but (since + // their uses occur in the predecessor block, not the block the PHI lives in) + // should still use the post-inc value. Check for this case now. + PHINode *PN = dyn_cast<PHINode>(User); + if (!PN) return false; // not a phi, not dominated by latch block. + + // Look at all of the uses of IV by the PHI node. If any use corresponds to + // a block that is not dominated by the latch block, give up and use the + // preincremented value. 
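  // Illustration: an exit block with several predecessors is not dominated by
  // the latch, but a phi such as
  //   %m = phi i32 [ %iv, %latch ], [ 0, %other.pred ]
  // only uses the IV along the latch edge, so the post-incremented value is
  // still safe; an IV-incoming edge from a block the latch does not dominate
  // sends us back to the pre-incremented value.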
+ unsigned NumUses = 0; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingValue(i) == IV) { + ++NumUses; + if (!DT->dominates(LatchBlock, PN->getIncomingBlock(i))) + return false; + } + + // Okay, all uses of IV by PN are in predecessor blocks that really are + // dominated by the latch block. Use the post-incremented value. + return true; +} + +/// AddUsersIfInteresting - Inspect the specified instruction. If it is a +/// reducible SCEV, recursively add its users to the IVUsesByStride set and +/// return true. Otherwise, return false. +bool IVUsers::AddUsersIfInteresting(Instruction *I) { + if (!SE->isSCEVable(I->getType())) + return false; // Void and FP expressions cannot be reduced. + + // LSR is not APInt clean, do not touch integers bigger than 64-bits. + if (SE->getTypeSizeInBits(I->getType()) > 64) + return false; + + if (!Processed.insert(I)) + return true; // Instruction already handled. + + // Get the symbolic expression for this instruction. + SCEVHandle ISE = SE->getSCEV(I); + if (isa<SCEVCouldNotCompute>(ISE)) return false; + + // Get the start and stride for this expression. + Loop *UseLoop = LI->getLoopFor(I->getParent()); + SCEVHandle Start = SE->getIntegerSCEV(0, ISE->getType()); + SCEVHandle Stride = Start; + bool isSigned = false; // Arbitrary initial value - pacifies compiler. + + if (!getSCEVStartAndStride(ISE, L, UseLoop, Start, Stride, isSigned, SE, DT)) + return false; // Non-reducible symbolic expression, bail out. + + SmallPtrSet<Instruction *, 4> UniqueUsers; + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) { + Instruction *User = cast<Instruction>(*UI); + if (!UniqueUsers.insert(User)) + continue; + + // Do not infinitely recurse on PHI nodes. + if (isa<PHINode>(User) && Processed.count(User)) + continue; + + // Descend recursively, but not into PHI nodes outside the current loop. + // It's important to see the entire expression outside the loop to get + // choices that depend on addressing mode use right, although we won't + // consider references ouside the loop in all cases. + // If User is already in Processed, we don't want to recurse into it again, + // but do want to record a second reference in the same instruction. + bool AddUserToIVUsers = false; + if (LI->getLoopFor(User->getParent()) != L) { + if (isa<PHINode>(User) || Processed.count(User) || + !AddUsersIfInteresting(User)) { + DOUT << "FOUND USER in other loop: " << *User + << " OF SCEV: " << *ISE << "\n"; + AddUserToIVUsers = true; + } + } else if (Processed.count(User) || + !AddUsersIfInteresting(User)) { + DOUT << "FOUND USER: " << *User + << " OF SCEV: " << *ISE << "\n"; + AddUserToIVUsers = true; + } + + if (AddUserToIVUsers) { + IVUsersOfOneStride *StrideUses = IVUsesByStride[Stride]; + if (!StrideUses) { // First occurrence of this stride? + StrideOrder.push_back(Stride); + StrideUses = new IVUsersOfOneStride(Stride); + IVUses.push_back(StrideUses); + IVUsesByStride[Stride] = StrideUses; + } + + // Okay, we found a user that we cannot reduce. Analyze the instruction + // and decide what to do with it. If we are a use inside of the loop, use + // the value before incrementation, otherwise use it after incrementation. + if (IVUseShouldUsePostIncValue(User, I, L, LI, DT, this)) { + // The value used will be incremented by the stride more than we are + // expecting, so subtract this off. 
+ SCEVHandle NewStart = SE->getMinusSCEV(Start, Stride); + StrideUses->addUser(NewStart, User, I, isSigned); + StrideUses->Users.back().setIsUseOfPostIncrementedValue(true); + DOUT << " USING POSTINC SCEV, START=" << *NewStart<< "\n"; + } else { + StrideUses->addUser(Start, User, I, isSigned); + } + } + } + return true; +} + +IVUsers::IVUsers() + : LoopPass(&ID) { +} + +void IVUsers::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<LoopInfo>(); + AU.addRequired<DominatorTree>(); + AU.addRequired<ScalarEvolution>(); + AU.setPreservesAll(); +} + +bool IVUsers::runOnLoop(Loop *l, LPPassManager &LPM) { + + L = l; + LI = &getAnalysis<LoopInfo>(); + DT = &getAnalysis<DominatorTree>(); + SE = &getAnalysis<ScalarEvolution>(); + + // Find all uses of induction variables in this loop, and categorize + // them by stride. Start by finding all of the PHI nodes in the header for + // this loop. If they are induction variables, inspect their uses. + for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) + AddUsersIfInteresting(I); + + return false; +} + +/// getReplacementExpr - Return a SCEV expression which computes the +/// value of the OperandValToReplace of the given IVStrideUse. +SCEVHandle IVUsers::getReplacementExpr(const IVStrideUse &U) const { + const Type *UseTy = U.getOperandValToReplace()->getType(); + // Start with zero. + SCEVHandle RetVal = SE->getIntegerSCEV(0, U.getParent()->Stride->getType()); + // Create the basic add recurrence. + RetVal = SE->getAddRecExpr(RetVal, U.getParent()->Stride, L); + // Add the offset in a separate step, because it may be loop-variant. + RetVal = SE->getAddExpr(RetVal, U.getOffset()); + // For uses of post-incremented values, add an extra stride to compute + // the actual replacement value. + if (U.isUseOfPostIncrementedValue()) + RetVal = SE->getAddExpr(RetVal, U.getParent()->Stride); + // Evaluate the expression out of the loop, if possible. + if (!L->contains(U.getUser()->getParent())) { + SCEVHandle ExitVal = SE->getSCEVAtScope(RetVal, L->getParentLoop()); + if (!isa<SCEVCouldNotCompute>(ExitVal) && ExitVal->isLoopInvariant(L)) + RetVal = ExitVal; + } + // Promote the result to the type of the use. 
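  // Illustrative recap: at this point RetVal is {0,+,Stride}<L> plus the
  // recorded offset, plus one extra Stride for uses of the post-incremented
  // value (and possibly already folded to its loop-exit value); the code below
  // only widens or narrows that expression to the use's type.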
+ if (SE->getTypeSizeInBits(RetVal->getType()) != + SE->getTypeSizeInBits(UseTy)) { + if (U.isSigned()) + RetVal = SE->getSignExtendExpr(RetVal, UseTy); + else + RetVal = SE->getZeroExtendExpr(RetVal, UseTy); + } + return RetVal; +} + +void IVUsers::print(raw_ostream &OS, const Module *M) const { + OS << "IV Users for loop "; + WriteAsOperand(OS, L->getHeader(), false); + if (SE->hasLoopInvariantBackedgeTakenCount(L)) { + OS << " with backedge-taken count " + << *SE->getBackedgeTakenCount(L); + } + OS << ":\n"; + + for (unsigned Stride = 0, e = StrideOrder.size(); Stride != e; ++Stride) { + std::map<SCEVHandle, IVUsersOfOneStride*>::const_iterator SI = + IVUsesByStride.find(StrideOrder[Stride]); + assert(SI != IVUsesByStride.end() && "Stride doesn't exist!"); + OS << " Stride " << *SI->first->getType() << " " << *SI->first << ":\n"; + + for (ilist<IVStrideUse>::const_iterator UI = SI->second->Users.begin(), + E = SI->second->Users.end(); UI != E; ++UI) { + OS << " "; + WriteAsOperand(OS, UI->getOperandValToReplace(), false); + OS << " = "; + OS << *getReplacementExpr(*UI); + if (UI->isUseOfPostIncrementedValue()) + OS << " (post-inc)"; + OS << " in "; + UI->getUser()->print(OS); + } + } +} + +void IVUsers::print(std::ostream &o, const Module *M) const { + raw_os_ostream OS(o); + print(OS, M); +} + +void IVUsers::dump() const { + print(errs()); +} + +void IVUsers::releaseMemory() { + IVUsesByStride.clear(); + StrideOrder.clear(); + Processed.clear(); +} + +void IVStrideUse::deleted() { + // Remove this user from the list. + Parent->Users.erase(this); + // this now dangles! +} diff --git a/lib/Analysis/InstCount.cpp b/lib/Analysis/InstCount.cpp new file mode 100644 index 0000000..2dea7b3 --- /dev/null +++ b/lib/Analysis/InstCount.cpp @@ -0,0 +1,86 @@ +//===-- InstCount.cpp - Collects the count of all instructions ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This pass collects the count of all instructions and reports them +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "instcount" +#include "llvm/Analysis/Passes.h" +#include "llvm/Pass.h" +#include "llvm/Function.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/InstVisitor.h" +#include "llvm/Support/Streams.h" +#include "llvm/ADT/Statistic.h" +#include <ostream> +using namespace llvm; + +STATISTIC(TotalInsts , "Number of instructions (of all types)"); +STATISTIC(TotalBlocks, "Number of basic blocks"); +STATISTIC(TotalFuncs , "Number of non-external functions"); +STATISTIC(TotalMemInst, "Number of memory instructions"); + +#define HANDLE_INST(N, OPCODE, CLASS) \ + STATISTIC(Num ## OPCODE ## Inst, "Number of " #OPCODE " insts"); + +#include "llvm/Instruction.def" + + +namespace { + class VISIBILITY_HIDDEN InstCount + : public FunctionPass, public InstVisitor<InstCount> { + friend class InstVisitor<InstCount>; + + void visitFunction (Function &F) { ++TotalFuncs; } + void visitBasicBlock(BasicBlock &BB) { ++TotalBlocks; } + +#define HANDLE_INST(N, OPCODE, CLASS) \ + void visit##OPCODE(CLASS &) { ++Num##OPCODE##Inst; ++TotalInsts; } + +#include "llvm/Instruction.def" + + void visitInstruction(Instruction &I) { + cerr << "Instruction Count does not know about " << I; + abort(); + } + public: + static char ID; // Pass identification, replacement for typeid + InstCount() : FunctionPass(&ID) {} + + virtual bool runOnFunction(Function &F); + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + virtual void print(std::ostream &O, const Module *M) const {} + + }; +} + +char InstCount::ID = 0; +static RegisterPass<InstCount> +X("instcount", "Counts the various types of Instructions", false, true); + +FunctionPass *llvm::createInstCountPass() { return new InstCount(); } + +// InstCount::run - This is the main Analysis entry point for a +// function. +// +bool InstCount::runOnFunction(Function &F) { + unsigned StartMemInsts = + NumGetElementPtrInst + NumLoadInst + NumStoreInst + NumCallInst + + NumInvokeInst + NumAllocaInst + NumMallocInst + NumFreeInst; + visit(F); + unsigned EndMemInsts = + NumGetElementPtrInst + NumLoadInst + NumStoreInst + NumCallInst + + NumInvokeInst + NumAllocaInst + NumMallocInst + NumFreeInst; + TotalMemInst += EndMemInsts-StartMemInsts; + return false; +} diff --git a/lib/Analysis/Interval.cpp b/lib/Analysis/Interval.cpp new file mode 100644 index 0000000..16b1947 --- /dev/null +++ b/lib/Analysis/Interval.cpp @@ -0,0 +1,57 @@ +//===- Interval.cpp - Interval class code ---------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the Interval class, which represents a +// partition of a control flow graph of some kind. 
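// Informally, an interval is a maximal single-entry region: every node in it
// other than the header has all of its predecessors inside the interval, so
// control can only enter the region through the header.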
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Interval.h" +#include "llvm/BasicBlock.h" +#include "llvm/Support/CFG.h" +#include <algorithm> + +using namespace llvm; + +//===----------------------------------------------------------------------===// +// Interval Implementation +//===----------------------------------------------------------------------===// + +// isLoop - Find out if there is a back edge in this interval... +// +bool Interval::isLoop() const { + // There is a loop in this interval iff one of the predecessors of the header + // node lives in the interval. + for (::pred_iterator I = ::pred_begin(HeaderNode), E = ::pred_end(HeaderNode); + I != E; ++I) { + if (contains(*I)) return true; + } + return false; +} + + +void Interval::print(std::ostream &o) const { + o << "-------------------------------------------------------------\n" + << "Interval Contents:\n"; + + // Print out all of the basic blocks in the interval... + for (std::vector<BasicBlock*>::const_iterator I = Nodes.begin(), + E = Nodes.end(); I != E; ++I) + o << **I << "\n"; + + o << "Interval Predecessors:\n"; + for (std::vector<BasicBlock*>::const_iterator I = Predecessors.begin(), + E = Predecessors.end(); I != E; ++I) + o << **I << "\n"; + + o << "Interval Successors:\n"; + for (std::vector<BasicBlock*>::const_iterator I = Successors.begin(), + E = Successors.end(); I != E; ++I) + o << **I << "\n"; +} diff --git a/lib/Analysis/IntervalPartition.cpp b/lib/Analysis/IntervalPartition.cpp new file mode 100644 index 0000000..cb8a85d --- /dev/null +++ b/lib/Analysis/IntervalPartition.cpp @@ -0,0 +1,114 @@ +//===- IntervalPartition.cpp - Interval Partition module code -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the IntervalPartition class, which +// calculates and represent the interval partition of a function. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IntervalIterator.h" +using namespace llvm; + +char IntervalPartition::ID = 0; +static RegisterPass<IntervalPartition> +X("intervals", "Interval Partition Construction", true, true); + +//===----------------------------------------------------------------------===// +// IntervalPartition Implementation +//===----------------------------------------------------------------------===// + +// releaseMemory - Reset state back to before function was analyzed +void IntervalPartition::releaseMemory() { + for (unsigned i = 0, e = Intervals.size(); i != e; ++i) + delete Intervals[i]; + IntervalMap.clear(); + Intervals.clear(); + RootInterval = 0; +} + +void IntervalPartition::print(std::ostream &O, const Module*) const { + for(unsigned i = 0, e = Intervals.size(); i != e; ++i) + Intervals[i]->print(O); +} + +// addIntervalToPartition - Add an interval to the internal list of intervals, +// and then add mappings from all of the basic blocks in the interval to the +// interval itself (in the IntervalMap). 
+// +void IntervalPartition::addIntervalToPartition(Interval *I) { + Intervals.push_back(I); + + // Add mappings for all of the basic blocks in I to the IntervalPartition + for (Interval::node_iterator It = I->Nodes.begin(), End = I->Nodes.end(); + It != End; ++It) + IntervalMap.insert(std::make_pair(*It, I)); +} + +// updatePredecessors - Interval generation only sets the successor fields of +// the interval data structures. After interval generation is complete, +// run through all of the intervals and propagate successor info as +// predecessor info. +// +void IntervalPartition::updatePredecessors(Interval *Int) { + BasicBlock *Header = Int->getHeaderNode(); + for (Interval::succ_iterator I = Int->Successors.begin(), + E = Int->Successors.end(); I != E; ++I) + getBlockInterval(*I)->Predecessors.push_back(Header); +} + +// IntervalPartition ctor - Build the first level interval partition for the +// specified function... +// +bool IntervalPartition::runOnFunction(Function &F) { + // Pass false to intervals_begin because we take ownership of it's memory + function_interval_iterator I = intervals_begin(&F, false); + assert(I != intervals_end(&F) && "No intervals in function!?!?!"); + + addIntervalToPartition(RootInterval = *I); + + ++I; // After the first one... + + // Add the rest of the intervals to the partition. + for (function_interval_iterator E = intervals_end(&F); I != E; ++I) + addIntervalToPartition(*I); + + // Now that we know all of the successor information, propagate this to the + // predecessors for each block. + for (unsigned i = 0, e = Intervals.size(); i != e; ++i) + updatePredecessors(Intervals[i]); + return false; +} + + +// IntervalPartition ctor - Build a reduced interval partition from an +// existing interval graph. This takes an additional boolean parameter to +// distinguish it from a copy constructor. Always pass in false for now. +// +IntervalPartition::IntervalPartition(IntervalPartition &IP, bool) + : FunctionPass(&ID) { + assert(IP.getRootInterval() && "Cannot operate on empty IntervalPartitions!"); + + // Pass false to intervals_begin because we take ownership of it's memory + interval_part_interval_iterator I = intervals_begin(IP, false); + assert(I != intervals_end(IP) && "No intervals in interval partition!?!?!"); + + addIntervalToPartition(RootInterval = *I); + + ++I; // After the first one... + + // Add the rest of the intervals to the partition. + for (interval_part_interval_iterator E = intervals_end(IP); I != E; ++I) + addIntervalToPartition(*I); + + // Now that we know all of the successor information, propagate this to the + // predecessors for each block. + for (unsigned i = 0, e = Intervals.size(); i != e; ++i) + updatePredecessors(Intervals[i]); +} + diff --git a/lib/Analysis/LibCallAliasAnalysis.cpp b/lib/Analysis/LibCallAliasAnalysis.cpp new file mode 100644 index 0000000..971e6e7 --- /dev/null +++ b/lib/Analysis/LibCallAliasAnalysis.cpp @@ -0,0 +1,141 @@ +//===- LibCallAliasAnalysis.cpp - Implement AliasAnalysis for libcalls ----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the LibCallAliasAnalysis class. 
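// As a rough, hypothetical sketch of the data this pass consumes: a client's
// LibCallInfo table might record, for a routine such as "qsort", a
// UniversalBehavior of ModRef, a DetailsType of DoesOnly, and a LocationDetails
// array (terminated by LocationID == ~0U) describing the array being sorted,
// so that queries about pointers provably unrelated to that location can be
// folded to NoModRef below.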
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LibCallAliasAnalysis.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/LibCallSemantics.h" +#include "llvm/Function.h" +#include "llvm/Pass.h" +#include "llvm/Target/TargetData.h" +using namespace llvm; + +// Register this pass... +char LibCallAliasAnalysis::ID = 0; +static RegisterPass<LibCallAliasAnalysis> +X("libcall-aa", "LibCall Alias Analysis", false, true); + +// Declare that we implement the AliasAnalysis interface +static RegisterAnalysisGroup<AliasAnalysis> Y(X); + +FunctionPass *llvm::createLibCallAliasAnalysisPass(LibCallInfo *LCI) { + return new LibCallAliasAnalysis(LCI); +} + +LibCallAliasAnalysis::~LibCallAliasAnalysis() { + delete LCI; +} + +void LibCallAliasAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { + AliasAnalysis::getAnalysisUsage(AU); + AU.addRequired<TargetData>(); + AU.setPreservesAll(); // Does not transform code +} + + + +/// AnalyzeLibCallDetails - Given a call to a function with the specified +/// LibCallFunctionInfo, see if we can improve the mod/ref footprint of the call +/// vs the specified pointer/size. +AliasAnalysis::ModRefResult +LibCallAliasAnalysis::AnalyzeLibCallDetails(const LibCallFunctionInfo *FI, + CallSite CS, Value *P, + unsigned Size) { + // If we have a function, check to see what kind of mod/ref effects it + // has. Start by including any info globally known about the function. + AliasAnalysis::ModRefResult MRInfo = FI->UniversalBehavior; + if (MRInfo == NoModRef) return MRInfo; + + // If that didn't tell us that the function is 'readnone', check to see + // if we have detailed info and if 'P' is any of the locations we know + // about. + const LibCallFunctionInfo::LocationMRInfo *Details = FI->LocationDetails; + if (Details == 0) + return MRInfo; + + // If the details array is of the 'DoesNot' kind, we only know something if + // the pointer is a match for one of the locations in 'Details'. If we find a + // match, we can prove some interactions cannot happen. + // + if (FI->DetailsType == LibCallFunctionInfo::DoesNot) { + // Find out if the pointer refers to a known location. + for (unsigned i = 0; Details[i].LocationID != ~0U; ++i) { + const LibCallLocationInfo &Loc = + LCI->getLocationInfo(Details[i].LocationID); + LibCallLocationInfo::LocResult Res = Loc.isLocation(CS, P, Size); + if (Res != LibCallLocationInfo::Yes) continue; + + // If we find a match against a location that we 'do not' interact with, + // learn this info into MRInfo. + return ModRefResult(MRInfo & ~Details[i].MRInfo); + } + return MRInfo; + } + + // If the details are of the 'DoesOnly' sort, we know something if the pointer + // is a match for one of the locations in 'Details'. Also, if we can prove + // that the pointers is *not* one of the locations in 'Details', we know that + // the call is NoModRef. + assert(FI->DetailsType == LibCallFunctionInfo::DoesOnly); + + // Find out if the pointer refers to a known location. + bool NoneMatch = true; + for (unsigned i = 0; Details[i].LocationID != ~0U; ++i) { + const LibCallLocationInfo &Loc = + LCI->getLocationInfo(Details[i].LocationID); + LibCallLocationInfo::LocResult Res = Loc.isLocation(CS, P, Size); + if (Res == LibCallLocationInfo::No) continue; + + // If we don't know if this pointer points to the location, then we have to + // assume it might alias in some case. 
+ if (Res == LibCallLocationInfo::Unknown) { + NoneMatch = false; + continue; + } + + // If we know that this pointer definitely is pointing into the location, + // merge in this information. + return ModRefResult(MRInfo & Details[i].MRInfo); + } + + // If we found that the pointer is guaranteed to not match any of the + // locations in our 'DoesOnly' rule, then we know that the pointer must point + // to some other location. Since the libcall doesn't mod/ref any other + // locations, return NoModRef. + if (NoneMatch) + return NoModRef; + + // Otherwise, return any other info gained so far. + return MRInfo; +} + +// getModRefInfo - Check to see if the specified callsite can clobber the +// specified memory object. +// +AliasAnalysis::ModRefResult +LibCallAliasAnalysis::getModRefInfo(CallSite CS, Value *P, unsigned Size) { + ModRefResult MRInfo = ModRef; + + // If this is a direct call to a function that LCI knows about, get the + // information about the runtime function. + if (LCI) { + if (Function *F = CS.getCalledFunction()) { + if (const LibCallFunctionInfo *FI = LCI->getFunctionInfo(F)) { + MRInfo = ModRefResult(MRInfo & AnalyzeLibCallDetails(FI, CS, P, Size)); + if (MRInfo == NoModRef) return NoModRef; + } + } + } + + // The AliasAnalysis base class has some smarts, lets use them. + return (ModRefResult)(MRInfo | AliasAnalysis::getModRefInfo(CS, P, Size)); +} diff --git a/lib/Analysis/LibCallSemantics.cpp b/lib/Analysis/LibCallSemantics.cpp new file mode 100644 index 0000000..2985047 --- /dev/null +++ b/lib/Analysis/LibCallSemantics.cpp @@ -0,0 +1,65 @@ +//===- LibCallSemantics.cpp - Describe library semantics ------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements interfaces that can be used to describe language +// specific runtime library interfaces (e.g. libc, libm, etc) to LLVM +// optimizers. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LibCallSemantics.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Function.h" +using namespace llvm; + +/// getMap - This impl pointer in ~LibCallInfo is actually a StringMap. This +/// helper does the cast. +static StringMap<const LibCallFunctionInfo*> *getMap(void *Ptr) { + return static_cast<StringMap<const LibCallFunctionInfo*> *>(Ptr); +} + +LibCallInfo::~LibCallInfo() { + delete getMap(Impl); +} + +const LibCallLocationInfo &LibCallInfo::getLocationInfo(unsigned LocID) const { + // Get location info on the first call. + if (NumLocations == 0) + NumLocations = getLocationInfo(Locations); + + assert(LocID < NumLocations && "Invalid location ID!"); + return Locations[LocID]; +} + + +/// getFunctionInfo - Return the LibCallFunctionInfo object corresponding to +/// the specified function if we have it. If not, return null. +const LibCallFunctionInfo *LibCallInfo::getFunctionInfo(Function *F) const { + StringMap<const LibCallFunctionInfo*> *Map = getMap(Impl); + + /// If this is the first time we are querying for this info, lazily construct + /// the StringMap to index it. + if (Map == 0) { + Impl = Map = new StringMap<const LibCallFunctionInfo*>(); + + const LibCallFunctionInfo *Array = getFunctionInfoArray(); + if (Array == 0) return 0; + + // We now have the array of entries. Populate the StringMap. 
+ for (unsigned i = 0; Array[i].Name; ++i) + (*Map)[Array[i].Name] = Array+i; + } + + // Look up this function in the string map. + const char *ValueName = F->getNameStart(); + StringMap<const LibCallFunctionInfo*>::iterator I = + Map->find(ValueName, ValueName+F->getNameLen()); + return I != Map->end() ? I->second : 0; +} + diff --git a/lib/Analysis/LiveValues.cpp b/lib/Analysis/LiveValues.cpp new file mode 100644 index 0000000..2bbe98a --- /dev/null +++ b/lib/Analysis/LiveValues.cpp @@ -0,0 +1,191 @@ +//===- LiveValues.cpp - Liveness information for LLVM IR Values. ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the implementation for the LLVM IR Value liveness +// analysis pass. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LiveValues.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/LoopInfo.h" +using namespace llvm; + +FunctionPass *llvm::createLiveValuesPass() { return new LiveValues(); } + +char LiveValues::ID = 0; +static RegisterPass<LiveValues> +X("live-values", "Value Liveness Analysis", false, true); + +LiveValues::LiveValues() : FunctionPass(&ID) {} + +void LiveValues::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<DominatorTree>(); + AU.addRequired<LoopInfo>(); + AU.setPreservesAll(); +} + +bool LiveValues::runOnFunction(Function &F) { + DT = &getAnalysis<DominatorTree>(); + LI = &getAnalysis<LoopInfo>(); + + // This pass' values are computed lazily, so there's nothing to do here. + + return false; +} + +void LiveValues::releaseMemory() { + Memos.clear(); +} + +/// isUsedInBlock - Test if the given value is used in the given block. +/// +bool LiveValues::isUsedInBlock(const Value *V, const BasicBlock *BB) { + Memo &M = getMemo(V); + return M.Used.count(BB); +} + +/// isLiveThroughBlock - Test if the given value is known to be +/// live-through the given block, meaning that the block is properly +/// dominated by the value's definition, and there exists a block +/// reachable from it that contains a use. This uses a conservative +/// approximation that errs on the side of returning false. +/// +bool LiveValues::isLiveThroughBlock(const Value *V, + const BasicBlock *BB) { + Memo &M = getMemo(V); + return M.LiveThrough.count(BB); +} + +/// isKilledInBlock - Test if the given value is known to be killed in +/// the given block, meaning that the block contains a use of the value, +/// and no blocks reachable from the block contain a use. This uses a +/// conservative approximation that errs on the side of returning false. +/// +bool LiveValues::isKilledInBlock(const Value *V, const BasicBlock *BB) { + Memo &M = getMemo(V); + return M.Killed.count(BB); +} + +/// getMemo - Retrieve an existing Memo for the given value if one +/// is available, otherwise compute a new one. +/// +LiveValues::Memo &LiveValues::getMemo(const Value *V) { + DenseMap<const Value *, Memo>::iterator I = Memos.find(V); + if (I != Memos.end()) + return I->second; + return compute(V); +} + +/// getImmediateDominator - A handy utility for the specific DominatorTree +/// query that we need here. 
+/// +static const BasicBlock *getImmediateDominator(const BasicBlock *BB, + const DominatorTree *DT) { + DomTreeNode *Node = DT->getNode(const_cast<BasicBlock *>(BB))->getIDom(); + return Node ? Node->getBlock() : 0; +} + +/// compute - Compute a new Memo for the given value. +/// +LiveValues::Memo &LiveValues::compute(const Value *V) { + Memo &M = Memos[V]; + + // Determine the block containing the definition. + const BasicBlock *DefBB; + // Instructions define values with meaningful live ranges. + if (const Instruction *I = dyn_cast<Instruction>(V)) + DefBB = I->getParent(); + // Arguments can be analyzed as values defined in the entry block. + else if (const Argument *A = dyn_cast<Argument>(V)) + DefBB = &A->getParent()->getEntryBlock(); + // Constants and other things aren't meaningful here, so just + // return having computed an empty Memo so that we don't come + // here again. The assumption here is that client code won't + // be asking about such values very often. + else + return M; + + // Determine if the value is defined inside a loop. This is used + // to track whether the value is ever used outside the loop, so + // it'll be set to null if the value is either not defined in a + // loop or used outside the loop in which it is defined. + const Loop *L = LI->getLoopFor(DefBB); + + // Track whether the value is used anywhere outside of the block + // in which it is defined. + bool LiveOutOfDefBB = false; + + // Examine each use of the value. + for (Value::use_const_iterator I = V->use_begin(), E = V->use_end(); + I != E; ++I) { + const User *U = *I; + const BasicBlock *UseBB = cast<Instruction>(U)->getParent(); + + // Note the block in which this use occurs. + M.Used.insert(UseBB); + + // If the use block doesn't have successors, the value can be + // considered killed. + if (succ_begin(UseBB) == succ_end(UseBB)) + M.Killed.insert(UseBB); + + // Observe whether the value is used outside of the loop in which + // it is defined. Switch to an enclosing loop if necessary. + for (; L; L = L->getParentLoop()) + if (L->contains(UseBB)) + break; + + // Search for live-through blocks. + const BasicBlock *BB; + if (const PHINode *PHI = dyn_cast<PHINode>(U)) { + // For PHI nodes, start the search at the incoming block paired with the + // incoming value, which must be dominated by the definition. + unsigned Num = PHI->getIncomingValueNumForOperand(I.getOperandNo()); + BB = PHI->getIncomingBlock(Num); + + // A PHI-node use means the value is live-out of it's defining block + // even if that block also contains the only use. + LiveOutOfDefBB = true; + } else { + // Otherwise just start the search at the use. + BB = UseBB; + + // Note if the use is outside the defining block. + LiveOutOfDefBB |= UseBB != DefBB; + } + + // Climb the immediate dominator tree from the use to the definition + // and mark all intermediate blocks as live-through. + for (; BB != DefBB; BB = getImmediateDominator(BB, DT)) { + if (BB != UseBB && !M.LiveThrough.insert(BB)) + break; + } + } + + // If the value is defined inside a loop and is not live outside + // the loop, then each exit block of the loop in which the value + // is used is a kill block. 
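  // Illustration: for a value defined and used only inside loop L, each exiting
  // block of L that contains a use gets marked as killing it below; a value
  // that never escapes its defining block is simply killed there.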
+ if (L) { + SmallVector<BasicBlock *, 4> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { + const BasicBlock *ExitingBlock = ExitingBlocks[i]; + if (M.Used.count(ExitingBlock)) + M.Killed.insert(ExitingBlock); + } + } + + // If the value was never used outside the the block in which it was + // defined, it's killed in that block. + if (!LiveOutOfDefBB) + M.Killed.insert(DefBB); + + return M; +} diff --git a/lib/Analysis/LoopInfo.cpp b/lib/Analysis/LoopInfo.cpp new file mode 100644 index 0000000..de6480a --- /dev/null +++ b/lib/Analysis/LoopInfo.cpp @@ -0,0 +1,50 @@ +//===- LoopInfo.cpp - Natural Loop Calculator -----------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the LoopInfo class that is used to identify natural loops +// and determine the loop depth of various nodes of the CFG. Note that the +// loops identified may actually be several natural loops that share the same +// header node... not just a single natural loop. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Streams.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include <algorithm> +#include <ostream> +using namespace llvm; + +char LoopInfo::ID = 0; +static RegisterPass<LoopInfo> +X("loops", "Natural Loop Information", true, true); + +//===----------------------------------------------------------------------===// +// Loop implementation +// + +//===----------------------------------------------------------------------===// +// LoopInfo implementation +// +bool LoopInfo::runOnFunction(Function &) { + releaseMemory(); + LI->Calculate(getAnalysis<DominatorTree>().getBase()); // Update + return false; +} + +void LoopInfo::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<DominatorTree>(); +} diff --git a/lib/Analysis/LoopPass.cpp b/lib/Analysis/LoopPass.cpp new file mode 100644 index 0000000..08c25f4 --- /dev/null +++ b/lib/Analysis/LoopPass.cpp @@ -0,0 +1,340 @@ +//===- LoopPass.cpp - Loop Pass and Loop Pass Manager ---------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements LoopPass and LPPassManager. All loop optimization +// and transformation passes are derived from LoopPass. LPPassManager is +// responsible for managing LoopPasses. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopPass.h" +using namespace llvm; + +//===----------------------------------------------------------------------===// +// LPPassManager +// + +char LPPassManager::ID = 0; +/// LPPassManager manages FPPassManagers and CalLGraphSCCPasses. 
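/// In practice this manager schedules LoopPasses over the loop queue built in
/// runOnFunction below; a minimal client (an illustrative sketch -- MyLoopOpt
/// is a made-up name, following the same pattern IVUsers uses above) looks
/// roughly like:
///
///   namespace {
///     struct MyLoopOpt : public LoopPass {
///       static char ID;
///       MyLoopOpt() : LoopPass(&ID) {}
///       virtual bool runOnLoop(Loop *L, LPPassManager &LPM) { return false; }
///     };
///   }
///   char MyLoopOpt::ID = 0;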
+ +LPPassManager::LPPassManager(int Depth) + : FunctionPass(&ID), PMDataManager(Depth) { + skipThisLoop = false; + redoThisLoop = false; + LI = NULL; + CurrentLoop = NULL; +} + +/// Delete loop from the loop queue and loop hierarchy (LoopInfo). +void LPPassManager::deleteLoopFromQueue(Loop *L) { + + if (Loop *ParentLoop = L->getParentLoop()) { // Not a top-level loop. + // Reparent all of the blocks in this loop. Since BBLoop had a parent, + // they are now all in it. + for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); + I != E; ++I) + if (LI->getLoopFor(*I) == L) // Don't change blocks in subloops. + LI->changeLoopFor(*I, ParentLoop); + + // Remove the loop from its parent loop. + for (Loop::iterator I = ParentLoop->begin(), E = ParentLoop->end();; + ++I) { + assert(I != E && "Couldn't find loop"); + if (*I == L) { + ParentLoop->removeChildLoop(I); + break; + } + } + + // Move all subloops into the parent loop. + while (!L->empty()) + ParentLoop->addChildLoop(L->removeChildLoop(L->end()-1)); + } else { + // Reparent all of the blocks in this loop. Since BBLoop had no parent, + // they no longer in a loop at all. + + for (unsigned i = 0; i != L->getBlocks().size(); ++i) { + // Don't change blocks in subloops. + if (LI->getLoopFor(L->getBlocks()[i]) == L) { + LI->removeBlock(L->getBlocks()[i]); + --i; + } + } + + // Remove the loop from the top-level LoopInfo object. + for (LoopInfo::iterator I = LI->begin(), E = LI->end();; ++I) { + assert(I != E && "Couldn't find loop"); + if (*I == L) { + LI->removeLoop(I); + break; + } + } + + // Move all of the subloops to the top-level. + while (!L->empty()) + LI->addTopLevelLoop(L->removeChildLoop(L->end()-1)); + } + + delete L; + + // If L is current loop then skip rest of the passes and let + // runOnFunction remove L from LQ. Otherwise, remove L from LQ now + // and continue applying other passes on CurrentLoop. + if (CurrentLoop == L) { + skipThisLoop = true; + return; + } + + for (std::deque<Loop *>::iterator I = LQ.begin(), + E = LQ.end(); I != E; ++I) { + if (*I == L) { + LQ.erase(I); + break; + } + } +} + +// Inset loop into loop nest (LoopInfo) and loop queue (LQ). +void LPPassManager::insertLoop(Loop *L, Loop *ParentLoop) { + + assert (CurrentLoop != L && "Cannot insert CurrentLoop"); + + // Insert into loop nest + if (ParentLoop) + ParentLoop->addChildLoop(L); + else + LI->addTopLevelLoop(L); + + // Insert L into loop queue + if (L == CurrentLoop) + redoLoop(L); + else if (!ParentLoop) + // This is top level loop. + LQ.push_front(L); + else { + // Insert L after ParentLoop + for (std::deque<Loop *>::iterator I = LQ.begin(), + E = LQ.end(); I != E; ++I) { + if (*I == ParentLoop) { + // deque does not support insert after. + ++I; + LQ.insert(I, 1, L); + break; + } + } + } +} + +// Reoptimize this loop. LPPassManager will re-insert this loop into the +// queue. This allows LoopPass to change loop nest for the loop. This +// utility may send LPPassManager into infinite loops so use caution. +void LPPassManager::redoLoop(Loop *L) { + assert (CurrentLoop == L && "Can redo only CurrentLoop"); + redoThisLoop = true; +} + +/// cloneBasicBlockSimpleAnalysis - Invoke cloneBasicBlockAnalysis hook for +/// all loop passes. 
+void LPPassManager::cloneBasicBlockSimpleAnalysis(BasicBlock *From, + BasicBlock *To, Loop *L) { + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + Pass *P = getContainedPass(Index); + LoopPass *LP = dynamic_cast<LoopPass *>(P); + LP->cloneBasicBlockAnalysis(From, To, L); + } +} + +/// deleteSimpleAnalysisValue - Invoke deleteAnalysisValue hook for all passes. +void LPPassManager::deleteSimpleAnalysisValue(Value *V, Loop *L) { + if (BasicBlock *BB = dyn_cast<BasicBlock>(V)) { + for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE; + ++BI) { + Instruction &I = *BI; + deleteSimpleAnalysisValue(&I, L); + } + } + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + Pass *P = getContainedPass(Index); + LoopPass *LP = dynamic_cast<LoopPass *>(P); + LP->deleteAnalysisValue(V, L); + } +} + + +// Recurse through all subloops and all loops into LQ. +static void addLoopIntoQueue(Loop *L, std::deque<Loop *> &LQ) { + LQ.push_back(L); + for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) + addLoopIntoQueue(*I, LQ); +} + +/// Pass Manager itself does not invalidate any analysis info. +void LPPassManager::getAnalysisUsage(AnalysisUsage &Info) const { + // LPPassManager needs LoopInfo. In the long term LoopInfo class will + // become part of LPPassManager. + Info.addRequired<LoopInfo>(); + Info.setPreservesAll(); +} + +/// run - Execute all of the passes scheduled for execution. Keep track of +/// whether any of the passes modifies the function, and if so, return true. +bool LPPassManager::runOnFunction(Function &F) { + LI = &getAnalysis<LoopInfo>(); + bool Changed = false; + + // Collect inherited analysis from Module level pass manager. + populateInheritedAnalysis(TPM->activeStack); + + // Populate Loop Queue + for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) + addLoopIntoQueue(*I, LQ); + + // Initialization + for (std::deque<Loop *>::const_iterator I = LQ.begin(), E = LQ.end(); + I != E; ++I) { + Loop *L = *I; + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + Pass *P = getContainedPass(Index); + LoopPass *LP = dynamic_cast<LoopPass *>(P); + if (LP) + Changed |= LP->doInitialization(L, *this); + } + } + + // Walk Loops + while (!LQ.empty()) { + + CurrentLoop = LQ.back(); + skipThisLoop = false; + redoThisLoop = false; + + // Run all passes on current SCC + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + Pass *P = getContainedPass(Index); + + dumpPassInfo(P, EXECUTION_MSG, ON_LOOP_MSG, ""); + dumpRequiredSet(P); + + initializeAnalysisImpl(P); + + LoopPass *LP = dynamic_cast<LoopPass *>(P); + { + PassManagerPrettyStackEntry X(LP, *CurrentLoop->getHeader()); + StartPassTimer(P); + assert(LP && "Invalid LPPassManager member"); + Changed |= LP->runOnLoop(CurrentLoop, *this); + StopPassTimer(P); + } + + if (Changed) + dumpPassInfo(P, MODIFICATION_MSG, ON_LOOP_MSG, ""); + dumpPreservedSet(P); + + verifyPreservedAnalysis(LP); + removeNotPreservedAnalysis(P); + recordAvailableAnalysis(P); + removeDeadPasses(P, "", ON_LOOP_MSG); + + // If dominator information is available then verify the info if requested. + verifyDomInfo(*LP, F); + + if (skipThisLoop) + // Do not run other passes on this loop. + break; + } + + // Pop the loop from queue after running all passes. 
+ LQ.pop_back();
+
+ if (redoThisLoop)
+ LQ.push_back(CurrentLoop);
+ }
+
+ // Finalization
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ Pass *P = getContainedPass(Index);
+ LoopPass *LP = dynamic_cast <LoopPass *>(P);
+ if (LP)
+ Changed |= LP->doFinalization();
+ }
+
+ return Changed;
+}
+
+/// Print passes managed by this manager
+void LPPassManager::dumpPassStructure(unsigned Offset) {
+ llvm::cerr << std::string(Offset*2, ' ') << "Loop Pass Manager\n";
+ for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) {
+ Pass *P = getContainedPass(Index);
+ P->dumpPassStructure(Offset + 1);
+ dumpLastUses(P, Offset+1);
+ }
+}
+
+
+//===----------------------------------------------------------------------===//
+// LoopPass
+
+// Check if this pass is suitable for the current LPPassManager, if
+// available. This pass P is not suitable for an LPPassManager if P
+// does not preserve higher level analysis info used by other
+// LPPassManager passes. In such a case, pop LPPassManager from the
+// stack. This will force assignPassManager() to create a new
+// LPPassManager as expected.
+void LoopPass::preparePassManager(PMStack &PMS) {
+
+ // Find LPPassManager
+ while (!PMS.empty() &&
+ PMS.top()->getPassManagerType() > PMT_LoopPassManager)
+ PMS.pop();
+
+ LPPassManager *LPPM = dynamic_cast<LPPassManager *>(PMS.top());
+
+ // If this pass is destroying high level information that is used
+ // by other passes that are managed by LPM then do not insert
+ // this pass in current LPM. Use new LPPassManager.
+ if (LPPM && !LPPM->preserveHigherLevelAnalysis(this))
+ PMS.pop();
+}
+
+/// Assign pass manager to manage this pass.
+void LoopPass::assignPassManager(PMStack &PMS,
+ PassManagerType PreferredType) {
+ // Find LPPassManager
+ while (!PMS.empty() &&
+ PMS.top()->getPassManagerType() > PMT_LoopPassManager)
+ PMS.pop();
+
+ LPPassManager *LPPM = dynamic_cast<LPPassManager *>(PMS.top());
+
+ // Create new Loop Pass Manager if it does not exist.
+ if (!LPPM) {
+
+ assert (!PMS.empty() && "Unable to create Loop Pass Manager");
+ PMDataManager *PMD = PMS.top();
+
+ // [1] Create new Loop Pass Manager
+ LPPM = new LPPassManager(PMD->getDepth() + 1);
+ LPPM->populateInheritedAnalysis(PMS);
+
+ // [2] Set up new manager's top level manager
+ PMTopLevelManager *TPM = PMD->getTopLevelManager();
+ TPM->addIndirectPassManager(LPPM);
+
+ // [3] Assign manager to manage this new manager. This may create
+ // and push new managers into PMS
+ Pass *P = dynamic_cast<Pass *>(LPPM);
+ TPM->schedulePass(P);
+
+ // [4] Push new manager into PMS
+ PMS.push(LPPM);
+ }
+
+ LPPM->add(this);
+}
diff --git a/lib/Analysis/LoopVR.cpp b/lib/Analysis/LoopVR.cpp
new file mode 100644
index 0000000..0a3d06b
--- /dev/null
+++ b/lib/Analysis/LoopVR.cpp
@@ -0,0 +1,291 @@
+//===- LoopVR.cpp - Value Range analysis driven by loop information -------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines LoopVR, a value range analysis driven by loop
+// information: it computes conservative constant ranges for values defined
+// inside loops using ScalarEvolution and loop trip counts.
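+//
+// For example, for an induction variable that starts at 0, steps by 1, and
+// whose loop has a backedge-taken count of 99, LoopVR computes the range
+// [0, 100).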
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "loopvr" +#include "llvm/Analysis/LoopVR.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +char LoopVR::ID = 0; +static RegisterPass<LoopVR> X("loopvr", "Loop Value Ranges", false, true); + +/// getRange - determine the range for a particular SCEV within a given Loop +ConstantRange LoopVR::getRange(SCEVHandle S, Loop *L, ScalarEvolution &SE) { + SCEVHandle T = SE.getBackedgeTakenCount(L); + if (isa<SCEVCouldNotCompute>(T)) + return ConstantRange(cast<IntegerType>(S->getType())->getBitWidth(), true); + + T = SE.getTruncateOrZeroExtend(T, S->getType()); + return getRange(S, T, SE); +} + +/// getRange - determine the range for a particular SCEV with a given trip count +ConstantRange LoopVR::getRange(SCEVHandle S, SCEVHandle T, ScalarEvolution &SE){ + + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) + return ConstantRange(C->getValue()->getValue()); + + ConstantRange FullSet(cast<IntegerType>(S->getType())->getBitWidth(), true); + + // {x,+,y,+,...z}. We detect overflow by checking the size of the set after + // summing the upper and lower. + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { + ConstantRange X = getRange(Add->getOperand(0), T, SE); + if (X.isFullSet()) return FullSet; + for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) { + ConstantRange Y = getRange(Add->getOperand(i), T, SE); + if (Y.isFullSet()) return FullSet; + + APInt Spread_X = X.getSetSize(), Spread_Y = Y.getSetSize(); + APInt NewLower = X.getLower() + Y.getLower(); + APInt NewUpper = X.getUpper() + Y.getUpper() - 1; + if (NewLower == NewUpper) + return FullSet; + + X = ConstantRange(NewLower, NewUpper); + if (X.getSetSize().ult(Spread_X) || X.getSetSize().ult(Spread_Y)) + return FullSet; // we've wrapped, therefore, full set. + } + return X; + } + + // {x,*,y,*,...,z}. In order to detect overflow, we use k*bitwidth where + // k is the number of terms being multiplied. + if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { + ConstantRange X = getRange(Mul->getOperand(0), T, SE); + if (X.isFullSet()) return FullSet; + + const IntegerType *Ty = IntegerType::get(X.getBitWidth()); + const IntegerType *ExTy = IntegerType::get(X.getBitWidth() * + Mul->getNumOperands()); + ConstantRange XExt = X.zeroExtend(ExTy->getBitWidth()); + + for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) { + ConstantRange Y = getRange(Mul->getOperand(i), T, SE); + if (Y.isFullSet()) return FullSet; + + ConstantRange YExt = Y.zeroExtend(ExTy->getBitWidth()); + XExt = ConstantRange(XExt.getLower() * YExt.getLower(), + ((XExt.getUpper()-1) * (YExt.getUpper()-1)) + 1); + } + return XExt.truncate(Ty->getBitWidth()); + } + + // X smax Y smax ... Z is: range(smax(X_smin, Y_smin, ..., Z_smin), + // smax(X_smax, Y_smax, ..., Z_smax)) + // It doesn't matter if one of the SCEVs has FullSet because we're taking + // a maximum of the minimums across all of them. 
+ if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) { + ConstantRange X = getRange(SMax->getOperand(0), T, SE); + if (X.isFullSet()) return FullSet; + + APInt smin = X.getSignedMin(), smax = X.getSignedMax(); + for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) { + ConstantRange Y = getRange(SMax->getOperand(i), T, SE); + smin = APIntOps::smax(smin, Y.getSignedMin()); + smax = APIntOps::smax(smax, Y.getSignedMax()); + } + if (smax + 1 == smin) return FullSet; + return ConstantRange(smin, smax + 1); + } + + // X umax Y umax ... Z is: range(umax(X_umin, Y_umin, ..., Z_umin), + // umax(X_umax, Y_umax, ..., Z_umax)) + // It doesn't matter if one of the SCEVs has FullSet because we're taking + // a maximum of the minimums across all of them. + if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) { + ConstantRange X = getRange(UMax->getOperand(0), T, SE); + if (X.isFullSet()) return FullSet; + + APInt umin = X.getUnsignedMin(), umax = X.getUnsignedMax(); + for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) { + ConstantRange Y = getRange(UMax->getOperand(i), T, SE); + umin = APIntOps::umax(umin, Y.getUnsignedMin()); + umax = APIntOps::umax(umax, Y.getUnsignedMax()); + } + if (umax + 1 == umin) return FullSet; + return ConstantRange(umin, umax + 1); + } + + // L udiv R. Luckily, there's only ever 2 sides to a udiv. + if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) { + ConstantRange L = getRange(UDiv->getLHS(), T, SE); + ConstantRange R = getRange(UDiv->getRHS(), T, SE); + if (L.isFullSet() && R.isFullSet()) return FullSet; + + if (R.getUnsignedMax() == 0) { + // RHS must be single-element zero. Return an empty set. + return ConstantRange(R.getBitWidth(), false); + } + + APInt Lower = L.getUnsignedMin().udiv(R.getUnsignedMax()); + + APInt Upper; + + if (R.getUnsignedMin() == 0) { + // Just because it contains zero, doesn't mean it will also contain one. + // Use maximalIntersectWith to get the right behaviour. + ConstantRange NotZero(APInt(L.getBitWidth(), 1), + APInt::getNullValue(L.getBitWidth())); + R = R.maximalIntersectWith(NotZero); + } + + // But, the maximal intersection might still include zero. If it does, then + // we know it also included one. + if (R.contains(APInt::getNullValue(L.getBitWidth()))) + Upper = L.getUnsignedMax(); + else + Upper = L.getUnsignedMax().udiv(R.getUnsignedMin()); + + return ConstantRange(Lower, Upper); + } + + // ConstantRange already implements the cast operators. 
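+ // For zext, sext and trunc, compute the range of the operand (with the
+ // trip count converted to the operand's type) and then apply the same
+ // cast to the resulting range.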
+ + if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) { + T = SE.getTruncateOrZeroExtend(T, ZExt->getOperand()->getType()); + ConstantRange X = getRange(ZExt->getOperand(), T, SE); + return X.zeroExtend(cast<IntegerType>(ZExt->getType())->getBitWidth()); + } + + if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) { + T = SE.getTruncateOrZeroExtend(T, SExt->getOperand()->getType()); + ConstantRange X = getRange(SExt->getOperand(), T, SE); + return X.signExtend(cast<IntegerType>(SExt->getType())->getBitWidth()); + } + + if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { + T = SE.getTruncateOrZeroExtend(T, Trunc->getOperand()->getType()); + ConstantRange X = getRange(Trunc->getOperand(), T, SE); + if (X.isFullSet()) return FullSet; + return X.truncate(cast<IntegerType>(Trunc->getType())->getBitWidth()); + } + + if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { + const SCEVConstant *Trip = dyn_cast<SCEVConstant>(T); + if (!Trip) return FullSet; + + if (AddRec->isAffine()) { + SCEVHandle StartHandle = AddRec->getStart(); + SCEVHandle StepHandle = AddRec->getOperand(1); + + const SCEVConstant *Step = dyn_cast<SCEVConstant>(StepHandle); + if (!Step) return FullSet; + + uint32_t ExWidth = 2 * Trip->getValue()->getBitWidth(); + APInt TripExt = Trip->getValue()->getValue(); TripExt.zext(ExWidth); + APInt StepExt = Step->getValue()->getValue(); StepExt.zext(ExWidth); + if ((TripExt * StepExt).ugt(APInt::getLowBitsSet(ExWidth, ExWidth >> 1))) + return FullSet; + + SCEVHandle EndHandle = SE.getAddExpr(StartHandle, + SE.getMulExpr(T, StepHandle)); + const SCEVConstant *Start = dyn_cast<SCEVConstant>(StartHandle); + const SCEVConstant *End = dyn_cast<SCEVConstant>(EndHandle); + if (!Start || !End) return FullSet; + + const APInt &StartInt = Start->getValue()->getValue(); + const APInt &EndInt = End->getValue()->getValue(); + const APInt &StepInt = Step->getValue()->getValue(); + + if (StepInt.isNegative()) { + if (EndInt == StartInt + 1) return FullSet; + return ConstantRange(EndInt, StartInt + 1); + } else { + if (StartInt == EndInt + 1) return FullSet; + return ConstantRange(StartInt, EndInt + 1); + } + } + } + + // TODO: non-affine addrec, udiv, SCEVUnknown (narrowed from elsewhere)? 
+
+ return FullSet;
+}
+
+bool LoopVR::runOnFunction(Function &F) { Map.clear(); return false; }
+
+void LoopVR::print(std::ostream &os, const Module *) const {
+ raw_os_ostream OS(os);
+ for (std::map<Value *, ConstantRange *>::const_iterator I = Map.begin(),
+ E = Map.end(); I != E; ++I) {
+ OS << *I->first << ": " << *I->second << '\n';
+ }
+}
+
+void LoopVR::releaseMemory() {
+ for (std::map<Value *, ConstantRange *>::iterator I = Map.begin(),
+ E = Map.end(); I != E; ++I) {
+ delete I->second;
+ }
+
+ Map.clear();
+}
+
+ConstantRange LoopVR::compute(Value *V) {
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
+ return ConstantRange(CI->getValue());
+
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return ConstantRange(cast<IntegerType>(V->getType())->getBitWidth(), false);
+
+ LoopInfo &LI = getAnalysis<LoopInfo>();
+
+ Loop *L = LI.getLoopFor(I->getParent());
+ if (!L || L->isLoopInvariant(I))
+ return ConstantRange(cast<IntegerType>(V->getType())->getBitWidth(), false);
+
+ ScalarEvolution &SE = getAnalysis<ScalarEvolution>();
+
+ SCEVHandle S = SE.getSCEV(I);
+ if (isa<SCEVUnknown>(S) || isa<SCEVCouldNotCompute>(S))
+ return ConstantRange(cast<IntegerType>(V->getType())->getBitWidth(), false);
+
+ return ConstantRange(getRange(S, L, SE));
+}
+
+ConstantRange LoopVR::get(Value *V) {
+ std::map<Value *, ConstantRange *>::iterator I = Map.find(V);
+ if (I == Map.end()) {
+ ConstantRange *CR = new ConstantRange(compute(V));
+ Map[V] = CR;
+ return *CR;
+ }
+
+ return *I->second;
+}
+
+void LoopVR::remove(Value *V) {
+ std::map<Value *, ConstantRange *>::iterator I = Map.find(V);
+ if (I != Map.end()) {
+ delete I->second;
+ Map.erase(I);
+ }
+}
+
+void LoopVR::narrow(Value *V, const ConstantRange &CR) {
+ if (CR.isFullSet()) return;
+
+ std::map<Value *, ConstantRange *>::iterator I = Map.find(V);
+ if (I == Map.end())
+ Map[V] = new ConstantRange(CR);
+ else {
+ // Replace the stored range with the intersection, deleting the old
+ // range so it is not leaked.
+ ConstantRange *Old = I->second;
+ I->second = new ConstantRange(Old->maximalIntersectWith(CR));
+ delete Old;
+ }
+}
diff --git a/lib/Analysis/Makefile b/lib/Analysis/Makefile
new file mode 100644
index 0000000..4af6d35
--- /dev/null
+++ b/lib/Analysis/Makefile
@@ -0,0 +1,16 @@
+##===- lib/Analysis/Makefile -------------------------------*- Makefile -*-===##
+#
+# The LLVM Compiler Infrastructure
+#
+# This file is distributed under the University of Illinois Open Source
+# License. See LICENSE.TXT for details.
+#
+##===----------------------------------------------------------------------===##
+
+LEVEL = ../..
+LIBRARYNAME = LLVMAnalysis
+DIRS = IPA
+BUILD_ARCHIVE = 1
+
+include $(LEVEL)/Makefile.common
+
diff --git a/lib/Analysis/MemoryDependenceAnalysis.cpp b/lib/Analysis/MemoryDependenceAnalysis.cpp
new file mode 100644
index 0000000..3b21029
--- /dev/null
+++ b/lib/Analysis/MemoryDependenceAnalysis.cpp
@@ -0,0 +1,1142 @@
+//===- MemoryDependenceAnalysis.cpp - Mem Deps Implementation --*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an analysis that determines, for a given memory
+// operation, what preceding memory operations it depends on. It builds on
+// alias analysis information, and tries to provide a lazy, caching interface
+// to a common kind of alias information query.
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "memdep" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Function.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/PredIteratorCache.h" +#include "llvm/Support/Debug.h" +#include "llvm/Target/TargetData.h" +using namespace llvm; + +STATISTIC(NumCacheNonLocal, "Number of fully cached non-local responses"); +STATISTIC(NumCacheDirtyNonLocal, "Number of dirty cached non-local responses"); +STATISTIC(NumUncacheNonLocal, "Number of uncached non-local responses"); + +STATISTIC(NumCacheNonLocalPtr, + "Number of fully cached non-local ptr responses"); +STATISTIC(NumCacheDirtyNonLocalPtr, + "Number of cached, but dirty, non-local ptr responses"); +STATISTIC(NumUncacheNonLocalPtr, + "Number of uncached non-local ptr responses"); +STATISTIC(NumCacheCompleteNonLocalPtr, + "Number of block queries that were completely cached"); + +char MemoryDependenceAnalysis::ID = 0; + +// Register this pass... +static RegisterPass<MemoryDependenceAnalysis> X("memdep", + "Memory Dependence Analysis", false, true); + +MemoryDependenceAnalysis::MemoryDependenceAnalysis() +: FunctionPass(&ID), PredCache(0) { +} +MemoryDependenceAnalysis::~MemoryDependenceAnalysis() { +} + +/// Clean up memory in between runs +void MemoryDependenceAnalysis::releaseMemory() { + LocalDeps.clear(); + NonLocalDeps.clear(); + NonLocalPointerDeps.clear(); + ReverseLocalDeps.clear(); + ReverseNonLocalDeps.clear(); + ReverseNonLocalPtrDeps.clear(); + PredCache->clear(); +} + + + +/// getAnalysisUsage - Does not modify anything. It uses Alias Analysis. +/// +void MemoryDependenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<AliasAnalysis>(); + AU.addRequiredTransitive<TargetData>(); +} + +bool MemoryDependenceAnalysis::runOnFunction(Function &) { + AA = &getAnalysis<AliasAnalysis>(); + TD = &getAnalysis<TargetData>(); + if (PredCache == 0) + PredCache.reset(new PredIteratorCache()); + return false; +} + +/// RemoveFromReverseMap - This is a helper function that removes Val from +/// 'Inst's set in ReverseMap. If the set becomes empty, remove Inst's entry. +template <typename KeyTy> +static void RemoveFromReverseMap(DenseMap<Instruction*, + SmallPtrSet<KeyTy, 4> > &ReverseMap, + Instruction *Inst, KeyTy Val) { + typename DenseMap<Instruction*, SmallPtrSet<KeyTy, 4> >::iterator + InstIt = ReverseMap.find(Inst); + assert(InstIt != ReverseMap.end() && "Reverse map out of sync?"); + bool Found = InstIt->second.erase(Val); + assert(Found && "Invalid reverse map!"); Found=Found; + if (InstIt->second.empty()) + ReverseMap.erase(InstIt); +} + + +/// getCallSiteDependencyFrom - Private helper for finding the local +/// dependencies of a call site. 
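+/// The scan walks backwards from ScanIt within BB and stops at the first
+/// instruction whose memory behavior, as reported by alias analysis, may
+/// interfere with the call site.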
+MemDepResult MemoryDependenceAnalysis:: +getCallSiteDependencyFrom(CallSite CS, bool isReadOnlyCall, + BasicBlock::iterator ScanIt, BasicBlock *BB) { + // Walk backwards through the block, looking for dependencies + while (ScanIt != BB->begin()) { + Instruction *Inst = --ScanIt; + + // If this inst is a memory op, get the pointer it accessed + Value *Pointer = 0; + uint64_t PointerSize = 0; + if (StoreInst *S = dyn_cast<StoreInst>(Inst)) { + Pointer = S->getPointerOperand(); + PointerSize = TD->getTypeStoreSize(S->getOperand(0)->getType()); + } else if (VAArgInst *V = dyn_cast<VAArgInst>(Inst)) { + Pointer = V->getOperand(0); + PointerSize = TD->getTypeStoreSize(V->getType()); + } else if (FreeInst *F = dyn_cast<FreeInst>(Inst)) { + Pointer = F->getPointerOperand(); + + // FreeInsts erase the entire structure + PointerSize = ~0ULL; + } else if (isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) { + // Debug intrinsics don't cause dependences. + if (isa<DbgInfoIntrinsic>(Inst)) continue; + CallSite InstCS = CallSite::get(Inst); + // If these two calls do not interfere, look past it. + switch (AA->getModRefInfo(CS, InstCS)) { + case AliasAnalysis::NoModRef: + // If the two calls don't interact (e.g. InstCS is readnone) keep + // scanning. + continue; + case AliasAnalysis::Ref: + // If the two calls read the same memory locations and CS is a readonly + // function, then we have two cases: 1) the calls may not interfere with + // each other at all. 2) the calls may produce the same value. In case + // #1 we want to ignore the values, in case #2, we want to return Inst + // as a Def dependence. This allows us to CSE in cases like: + // X = strlen(P); + // memchr(...); + // Y = strlen(P); // Y = X + if (isReadOnlyCall) { + if (CS.getCalledFunction() != 0 && + CS.getCalledFunction() == InstCS.getCalledFunction()) + return MemDepResult::getDef(Inst); + // Ignore unrelated read/read call dependences. + continue; + } + // FALL THROUGH + default: + return MemDepResult::getClobber(Inst); + } + } else { + // Non-memory instruction. + continue; + } + + if (AA->getModRefInfo(CS, Pointer, PointerSize) != AliasAnalysis::NoModRef) + return MemDepResult::getClobber(Inst); + } + + // No dependence found. If this is the entry block of the function, it is a + // clobber, otherwise it is non-local. + if (BB != &BB->getParent()->getEntryBlock()) + return MemDepResult::getNonLocal(); + return MemDepResult::getClobber(ScanIt); +} + +/// getPointerDependencyFrom - Return the instruction on which a memory +/// location depends. If isLoad is true, this routine ignore may-aliases with +/// read-only operations. +MemDepResult MemoryDependenceAnalysis:: +getPointerDependencyFrom(Value *MemPtr, uint64_t MemSize, bool isLoad, + BasicBlock::iterator ScanIt, BasicBlock *BB) { + + // Walk backwards through the basic block, looking for dependencies. + while (ScanIt != BB->begin()) { + Instruction *Inst = --ScanIt; + + // Debug intrinsics don't cause dependences. + if (isa<DbgInfoIntrinsic>(Inst)) continue; + + // Values depend on loads if the pointers are must aliased. This means that + // a load depends on another must aliased load from the same value. + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + Value *Pointer = LI->getPointerOperand(); + uint64_t PointerSize = TD->getTypeStoreSize(LI->getType()); + + // If we found a pointer, check if it could be the same as our pointer. 
+ AliasAnalysis::AliasResult R = + AA->alias(Pointer, PointerSize, MemPtr, MemSize); + if (R == AliasAnalysis::NoAlias) + continue; + + // May-alias loads don't depend on each other without a dependence. + if (isLoad && R == AliasAnalysis::MayAlias) + continue; + // Stores depend on may and must aliased loads, loads depend on must-alias + // loads. + return MemDepResult::getDef(Inst); + } + + if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + // If alias analysis can tell that this store is guaranteed to not modify + // the query pointer, ignore it. Use getModRefInfo to handle cases where + // the query pointer points to constant memory etc. + if (AA->getModRefInfo(SI, MemPtr, MemSize) == AliasAnalysis::NoModRef) + continue; + + // Ok, this store might clobber the query pointer. Check to see if it is + // a must alias: in this case, we want to return this as a def. + Value *Pointer = SI->getPointerOperand(); + uint64_t PointerSize = TD->getTypeStoreSize(SI->getOperand(0)->getType()); + + // If we found a pointer, check if it could be the same as our pointer. + AliasAnalysis::AliasResult R = + AA->alias(Pointer, PointerSize, MemPtr, MemSize); + + if (R == AliasAnalysis::NoAlias) + continue; + if (R == AliasAnalysis::MayAlias) + return MemDepResult::getClobber(Inst); + return MemDepResult::getDef(Inst); + } + + // If this is an allocation, and if we know that the accessed pointer is to + // the allocation, return Def. This means that there is no dependence and + // the access can be optimized based on that. For example, a load could + // turn into undef. + if (AllocationInst *AI = dyn_cast<AllocationInst>(Inst)) { + Value *AccessPtr = MemPtr->getUnderlyingObject(); + + if (AccessPtr == AI || + AA->alias(AI, 1, AccessPtr, 1) == AliasAnalysis::MustAlias) + return MemDepResult::getDef(AI); + continue; + } + + // See if this instruction (e.g. a call or vaarg) mod/ref's the pointer. + switch (AA->getModRefInfo(Inst, MemPtr, MemSize)) { + case AliasAnalysis::NoModRef: + // If the call has no effect on the queried pointer, just ignore it. + continue; + case AliasAnalysis::Ref: + // If the call is known to never store to the pointer, and if this is a + // load query, we can safely ignore it (scan past it). + if (isLoad) + continue; + // FALL THROUGH. + default: + // Otherwise, there is a potential dependence. Return a clobber. + return MemDepResult::getClobber(Inst); + } + } + + // No dependence found. If this is the entry block of the function, it is a + // clobber, otherwise it is non-local. + if (BB != &BB->getParent()->getEntryBlock()) + return MemDepResult::getNonLocal(); + return MemDepResult::getClobber(ScanIt); +} + +/// getDependency - Return the instruction on which a memory operation +/// depends. +MemDepResult MemoryDependenceAnalysis::getDependency(Instruction *QueryInst) { + Instruction *ScanPos = QueryInst; + + // Check for a cached result + MemDepResult &LocalCache = LocalDeps[QueryInst]; + + // If the cached entry is non-dirty, just return it. Note that this depends + // on MemDepResult's default constructing to 'dirty'. + if (!LocalCache.isDirty()) + return LocalCache; + + // Otherwise, if we have a dirty entry, we know we can start the scan at that + // instruction, which may save us some work. + if (Instruction *Inst = LocalCache.getInst()) { + ScanPos = Inst; + + RemoveFromReverseMap(ReverseLocalDeps, Inst, QueryInst); + } + + BasicBlock *QueryParent = QueryInst->getParent(); + + Value *MemPtr = 0; + uint64_t MemSize = 0; + + // Do the scan. 
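+ // Dispatch on the kind of query instruction: loads and stores are handled
+ // by a pointer scan below, calls and invokes by a call-site scan, and
+ // volatile or non-memory instructions are conservatively reported as
+ // clobbered by the preceding instruction.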
+ if (BasicBlock::iterator(QueryInst) == QueryParent->begin()) { + // No dependence found. If this is the entry block of the function, it is a + // clobber, otherwise it is non-local. + if (QueryParent != &QueryParent->getParent()->getEntryBlock()) + LocalCache = MemDepResult::getNonLocal(); + else + LocalCache = MemDepResult::getClobber(QueryInst); + } else if (StoreInst *SI = dyn_cast<StoreInst>(QueryInst)) { + // If this is a volatile store, don't mess around with it. Just return the + // previous instruction as a clobber. + if (SI->isVolatile()) + LocalCache = MemDepResult::getClobber(--BasicBlock::iterator(ScanPos)); + else { + MemPtr = SI->getPointerOperand(); + MemSize = TD->getTypeStoreSize(SI->getOperand(0)->getType()); + } + } else if (LoadInst *LI = dyn_cast<LoadInst>(QueryInst)) { + // If this is a volatile load, don't mess around with it. Just return the + // previous instruction as a clobber. + if (LI->isVolatile()) + LocalCache = MemDepResult::getClobber(--BasicBlock::iterator(ScanPos)); + else { + MemPtr = LI->getPointerOperand(); + MemSize = TD->getTypeStoreSize(LI->getType()); + } + } else if (isa<CallInst>(QueryInst) || isa<InvokeInst>(QueryInst)) { + CallSite QueryCS = CallSite::get(QueryInst); + bool isReadOnly = AA->onlyReadsMemory(QueryCS); + LocalCache = getCallSiteDependencyFrom(QueryCS, isReadOnly, ScanPos, + QueryParent); + } else if (FreeInst *FI = dyn_cast<FreeInst>(QueryInst)) { + MemPtr = FI->getPointerOperand(); + // FreeInsts erase the entire structure, not just a field. + MemSize = ~0UL; + } else { + // Non-memory instruction. + LocalCache = MemDepResult::getClobber(--BasicBlock::iterator(ScanPos)); + } + + // If we need to do a pointer scan, make it happen. + if (MemPtr) + LocalCache = getPointerDependencyFrom(MemPtr, MemSize, + isa<LoadInst>(QueryInst), + ScanPos, QueryParent); + + // Remember the result! + if (Instruction *I = LocalCache.getInst()) + ReverseLocalDeps[I].insert(QueryInst); + + return LocalCache; +} + +#ifndef NDEBUG +/// AssertSorted - This method is used when -debug is specified to verify that +/// cache arrays are properly kept sorted. +static void AssertSorted(MemoryDependenceAnalysis::NonLocalDepInfo &Cache, + int Count = -1) { + if (Count == -1) Count = Cache.size(); + if (Count == 0) return; + + for (unsigned i = 1; i != unsigned(Count); ++i) + assert(Cache[i-1] <= Cache[i] && "Cache isn't sorted!"); +} +#endif + +/// getNonLocalCallDependency - Perform a full dependency query for the +/// specified call, returning the set of blocks that the value is +/// potentially live across. The returned set of results will include a +/// "NonLocal" result for all blocks where the value is live across. +/// +/// This method assumes the instruction returns a "NonLocal" dependency +/// within its own block. +/// +/// This returns a reference to an internal data structure that may be +/// invalidated on the next non-local query or when an instruction is +/// removed. Clients must copy this data if they want it around longer than +/// that. +const MemoryDependenceAnalysis::NonLocalDepInfo & +MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { + assert(getDependency(QueryCS.getInstruction()).isNonLocal() && + "getNonLocalCallDependency should only be used on calls with non-local deps!"); + PerInstNLInfo &CacheP = NonLocalDeps[QueryCS.getInstruction()]; + NonLocalDepInfo &Cache = CacheP.first; + + /// DirtyBlocks - This is the set of blocks that need to be recomputed. 
In + /// the cached case, this can happen due to instructions being deleted etc. In + /// the uncached case, this starts out as the set of predecessors we care + /// about. + SmallVector<BasicBlock*, 32> DirtyBlocks; + + if (!Cache.empty()) { + // Okay, we have a cache entry. If we know it is not dirty, just return it + // with no computation. + if (!CacheP.second) { + NumCacheNonLocal++; + return Cache; + } + + // If we already have a partially computed set of results, scan them to + // determine what is dirty, seeding our initial DirtyBlocks worklist. + for (NonLocalDepInfo::iterator I = Cache.begin(), E = Cache.end(); + I != E; ++I) + if (I->second.isDirty()) + DirtyBlocks.push_back(I->first); + + // Sort the cache so that we can do fast binary search lookups below. + std::sort(Cache.begin(), Cache.end()); + + ++NumCacheDirtyNonLocal; + //cerr << "CACHED CASE: " << DirtyBlocks.size() << " dirty: " + // << Cache.size() << " cached: " << *QueryInst; + } else { + // Seed DirtyBlocks with each of the preds of QueryInst's block. + BasicBlock *QueryBB = QueryCS.getInstruction()->getParent(); + for (BasicBlock **PI = PredCache->GetPreds(QueryBB); *PI; ++PI) + DirtyBlocks.push_back(*PI); + NumUncacheNonLocal++; + } + + // isReadonlyCall - If this is a read-only call, we can be more aggressive. + bool isReadonlyCall = AA->onlyReadsMemory(QueryCS); + + SmallPtrSet<BasicBlock*, 64> Visited; + + unsigned NumSortedEntries = Cache.size(); + DEBUG(AssertSorted(Cache)); + + // Iterate while we still have blocks to update. + while (!DirtyBlocks.empty()) { + BasicBlock *DirtyBB = DirtyBlocks.back(); + DirtyBlocks.pop_back(); + + // Already processed this block? + if (!Visited.insert(DirtyBB)) + continue; + + // Do a binary search to see if we already have an entry for this block in + // the cache set. If so, find it. + DEBUG(AssertSorted(Cache, NumSortedEntries)); + NonLocalDepInfo::iterator Entry = + std::upper_bound(Cache.begin(), Cache.begin()+NumSortedEntries, + std::make_pair(DirtyBB, MemDepResult())); + if (Entry != Cache.begin() && prior(Entry)->first == DirtyBB) + --Entry; + + MemDepResult *ExistingResult = 0; + if (Entry != Cache.begin()+NumSortedEntries && + Entry->first == DirtyBB) { + // If we already have an entry, and if it isn't already dirty, the block + // is done. + if (!Entry->second.isDirty()) + continue; + + // Otherwise, remember this slot so we can update the value. + ExistingResult = &Entry->second; + } + + // If the dirty entry has a pointer, start scanning from it so we don't have + // to rescan the entire block. + BasicBlock::iterator ScanPos = DirtyBB->end(); + if (ExistingResult) { + if (Instruction *Inst = ExistingResult->getInst()) { + ScanPos = Inst; + // We're removing QueryInst's use of Inst. + RemoveFromReverseMap(ReverseNonLocalDeps, Inst, + QueryCS.getInstruction()); + } + } + + // Find out if this block has a local dependency for QueryInst. + MemDepResult Dep; + + if (ScanPos != DirtyBB->begin()) { + Dep = getCallSiteDependencyFrom(QueryCS, isReadonlyCall,ScanPos, DirtyBB); + } else if (DirtyBB != &DirtyBB->getParent()->getEntryBlock()) { + // No dependence found. If this is the entry block of the function, it is + // a clobber, otherwise it is non-local. + Dep = MemDepResult::getNonLocal(); + } else { + Dep = MemDepResult::getClobber(ScanPos); + } + + // If we had a dirty entry for the block, update it. Otherwise, just add + // a new entry. 
+ if (ExistingResult)
+ *ExistingResult = Dep;
+ else
+ Cache.push_back(std::make_pair(DirtyBB, Dep));
+
+ // If the block has a dependency (i.e. it isn't completely transparent to
+ // the value), remember the association!
+ if (!Dep.isNonLocal()) {
+ // Keep the ReverseNonLocalDeps map up to date so we can efficiently
+ // update this when we remove instructions.
+ if (Instruction *Inst = Dep.getInst())
+ ReverseNonLocalDeps[Inst].insert(QueryCS.getInstruction());
+ } else {
+
+ // If the block *is* completely transparent to the load, we need to check
+ // the predecessors of this block. Add them to our worklist.
+ for (BasicBlock **PI = PredCache->GetPreds(DirtyBB); *PI; ++PI)
+ DirtyBlocks.push_back(*PI);
+ }
+ }
+
+ return Cache;
+}
+
+/// getNonLocalPointerDependency - Perform a full dependency query for an
+/// access to the specified (non-volatile) memory location, returning the
+/// set of instructions that either define or clobber the value.
+///
+/// This method assumes the pointer has a "NonLocal" dependency within its
+/// own block.
+///
+void MemoryDependenceAnalysis::
+getNonLocalPointerDependency(Value *Pointer, bool isLoad, BasicBlock *FromBB,
+ SmallVectorImpl<NonLocalDepEntry> &Result) {
+ assert(isa<PointerType>(Pointer->getType()) &&
+ "Can't get pointer deps of a non-pointer!");
+ Result.clear();
+
+ // We know that the pointer value is live into FromBB; find the def/clobbers
+ // from predecessors.
+ const Type *EltTy = cast<PointerType>(Pointer->getType())->getElementType();
+ uint64_t PointeeSize = TD->getTypeStoreSize(EltTy);
+
+ // This is the set of blocks we've inspected, and the pointer we consider in
+ // each block. Because of critical edges, we currently bail out if querying
+ // a block with multiple different pointers. This can happen during PHI
+ // translation.
+ DenseMap<BasicBlock*, Value*> Visited;
+ if (!getNonLocalPointerDepFromBB(Pointer, PointeeSize, isLoad, FromBB,
+ Result, Visited, true))
+ return;
+ Result.clear();
+ Result.push_back(std::make_pair(FromBB,
+ MemDepResult::getClobber(FromBB->begin())));
+}
+
+/// GetNonLocalInfoForBlock - Compute the memdep value for BB with
+/// Pointer/PointeeSize using either cached information in Cache or by doing a
+/// lookup (which may use dirty cache info if available). If we do a lookup,
+/// add the result to the cache.
+MemDepResult MemoryDependenceAnalysis::
+GetNonLocalInfoForBlock(Value *Pointer, uint64_t PointeeSize,
+ bool isLoad, BasicBlock *BB,
+ NonLocalDepInfo *Cache, unsigned NumSortedEntries) {
+
+ // Do a binary search to see if we already have an entry for this block in
+ // the cache set. If so, find it.
+ NonLocalDepInfo::iterator Entry =
+ std::upper_bound(Cache->begin(), Cache->begin()+NumSortedEntries,
+ std::make_pair(BB, MemDepResult()));
+ if (Entry != Cache->begin() && prior(Entry)->first == BB)
+ --Entry;
+
+ MemDepResult *ExistingResult = 0;
+ if (Entry != Cache->begin()+NumSortedEntries && Entry->first == BB)
+ ExistingResult = &Entry->second;
+
+ // If we have a cached entry, and it is non-dirty, use it as the value for
+ // this dependency.
+ if (ExistingResult && !ExistingResult->isDirty()) {
+ ++NumCacheNonLocalPtr;
+ return *ExistingResult;
+ }
+
+ // Otherwise, we have to scan for the value. If we have a dirty cache
+ // entry, start scanning from its position, otherwise we scan from the end
+ // of the block.
+ BasicBlock::iterator ScanPos = BB->end(); + if (ExistingResult && ExistingResult->getInst()) { + assert(ExistingResult->getInst()->getParent() == BB && + "Instruction invalidated?"); + ++NumCacheDirtyNonLocalPtr; + ScanPos = ExistingResult->getInst(); + + // Eliminating the dirty entry from 'Cache', so update the reverse info. + ValueIsLoadPair CacheKey(Pointer, isLoad); + RemoveFromReverseMap(ReverseNonLocalPtrDeps, ScanPos, CacheKey); + } else { + ++NumUncacheNonLocalPtr; + } + + // Scan the block for the dependency. + MemDepResult Dep = getPointerDependencyFrom(Pointer, PointeeSize, isLoad, + ScanPos, BB); + + // If we had a dirty entry for the block, update it. Otherwise, just add + // a new entry. + if (ExistingResult) + *ExistingResult = Dep; + else + Cache->push_back(std::make_pair(BB, Dep)); + + // If the block has a dependency (i.e. it isn't completely transparent to + // the value), remember the reverse association because we just added it + // to Cache! + if (Dep.isNonLocal()) + return Dep; + + // Keep the ReverseNonLocalPtrDeps map up to date so we can efficiently + // update MemDep when we remove instructions. + Instruction *Inst = Dep.getInst(); + assert(Inst && "Didn't depend on anything?"); + ValueIsLoadPair CacheKey(Pointer, isLoad); + ReverseNonLocalPtrDeps[Inst].insert(CacheKey); + return Dep; +} + + +/// getNonLocalPointerDepFromBB - Perform a dependency query based on +/// pointer/pointeesize starting at the end of StartBB. Add any clobber/def +/// results to the results vector and keep track of which blocks are visited in +/// 'Visited'. +/// +/// This has special behavior for the first block queries (when SkipFirstBlock +/// is true). In this special case, it ignores the contents of the specified +/// block and starts returning dependence info for its predecessors. +/// +/// This function returns false on success, or true to indicate that it could +/// not compute dependence information for some reason. This should be treated +/// as a clobber dependence on the first instruction in the predecessor block. +bool MemoryDependenceAnalysis:: +getNonLocalPointerDepFromBB(Value *Pointer, uint64_t PointeeSize, + bool isLoad, BasicBlock *StartBB, + SmallVectorImpl<NonLocalDepEntry> &Result, + DenseMap<BasicBlock*, Value*> &Visited, + bool SkipFirstBlock) { + + // Look up the cached info for Pointer. + ValueIsLoadPair CacheKey(Pointer, isLoad); + + std::pair<BBSkipFirstBlockPair, NonLocalDepInfo> *CacheInfo = + &NonLocalPointerDeps[CacheKey]; + NonLocalDepInfo *Cache = &CacheInfo->second; + + // If we have valid cached information for exactly the block we are + // investigating, just return it with no recomputation. + if (CacheInfo->first == BBSkipFirstBlockPair(StartBB, SkipFirstBlock)) { + // We have a fully cached result for this query then we can just return the + // cached results and populate the visited set. However, we have to verify + // that we don't already have conflicting results for these blocks. Check + // to ensure that if a block in the results set is in the visited set that + // it was for the same pointer query. + if (!Visited.empty()) { + for (NonLocalDepInfo::iterator I = Cache->begin(), E = Cache->end(); + I != E; ++I) { + DenseMap<BasicBlock*, Value*>::iterator VI = Visited.find(I->first); + if (VI == Visited.end() || VI->second == Pointer) continue; + + // We have a pointer mismatch in a block. Just return clobber, saying + // that something was clobbered in this result. 
We could also do a + // non-fully cached query, but there is little point in doing this. + return true; + } + } + + for (NonLocalDepInfo::iterator I = Cache->begin(), E = Cache->end(); + I != E; ++I) { + Visited.insert(std::make_pair(I->first, Pointer)); + if (!I->second.isNonLocal()) + Result.push_back(*I); + } + ++NumCacheCompleteNonLocalPtr; + return false; + } + + // Otherwise, either this is a new block, a block with an invalid cache + // pointer or one that we're about to invalidate by putting more info into it + // than its valid cache info. If empty, the result will be valid cache info, + // otherwise it isn't. + if (Cache->empty()) + CacheInfo->first = BBSkipFirstBlockPair(StartBB, SkipFirstBlock); + else + CacheInfo->first = BBSkipFirstBlockPair(); + + SmallVector<BasicBlock*, 32> Worklist; + Worklist.push_back(StartBB); + + // Keep track of the entries that we know are sorted. Previously cached + // entries will all be sorted. The entries we add we only sort on demand (we + // don't insert every element into its sorted position). We know that we + // won't get any reuse from currently inserted values, because we don't + // revisit blocks after we insert info for them. + unsigned NumSortedEntries = Cache->size(); + DEBUG(AssertSorted(*Cache)); + + while (!Worklist.empty()) { + BasicBlock *BB = Worklist.pop_back_val(); + + // Skip the first block if we have it. + if (!SkipFirstBlock) { + // Analyze the dependency of *Pointer in FromBB. See if we already have + // been here. + assert(Visited.count(BB) && "Should check 'visited' before adding to WL"); + + // Get the dependency info for Pointer in BB. If we have cached + // information, we will use it, otherwise we compute it. + DEBUG(AssertSorted(*Cache, NumSortedEntries)); + MemDepResult Dep = GetNonLocalInfoForBlock(Pointer, PointeeSize, isLoad, + BB, Cache, NumSortedEntries); + + // If we got a Def or Clobber, add this to the list of results. + if (!Dep.isNonLocal()) { + Result.push_back(NonLocalDepEntry(BB, Dep)); + continue; + } + } + + // If 'Pointer' is an instruction defined in this block, then we need to do + // phi translation to change it into a value live in the predecessor block. + // If phi translation fails, then we can't continue dependence analysis. + Instruction *PtrInst = dyn_cast<Instruction>(Pointer); + bool NeedsPHITranslation = PtrInst && PtrInst->getParent() == BB; + + // If no PHI translation is needed, just add all the predecessors of this + // block to scan them as well. + if (!NeedsPHITranslation) { + SkipFirstBlock = false; + for (BasicBlock **PI = PredCache->GetPreds(BB); *PI; ++PI) { + // Verify that we haven't looked at this block yet. + std::pair<DenseMap<BasicBlock*,Value*>::iterator, bool> + InsertRes = Visited.insert(std::make_pair(*PI, Pointer)); + if (InsertRes.second) { + // First time we've looked at *PI. + Worklist.push_back(*PI); + continue; + } + + // If we have seen this block before, but it was with a different + // pointer then we have a phi translation failure and we have to treat + // this as a clobber. + if (InsertRes.first->second != Pointer) + goto PredTranslationFailure; + } + continue; + } + + // If we do need to do phi translation, then there are a bunch of different + // cases, because we have to find a Value* live in the predecessor block. We + // know that PtrInst is defined in this block at least. + + // If this is directly a PHI node, just use the incoming values for each + // pred as the phi translated version. 
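+ // For example, if the pointer is 'P = phi [A, pred1], [B, pred2]', the
+ // query continues with A when scanning pred1 and with B when scanning
+ // pred2.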
+ if (PHINode *PtrPHI = dyn_cast<PHINode>(PtrInst)) { + for (BasicBlock **PI = PredCache->GetPreds(BB); *PI; ++PI) { + BasicBlock *Pred = *PI; + Value *PredPtr = PtrPHI->getIncomingValueForBlock(Pred); + + // Check to see if we have already visited this pred block with another + // pointer. If so, we can't do this lookup. This failure can occur + // with PHI translation when a critical edge exists and the PHI node in + // the successor translates to a pointer value different than the + // pointer the block was first analyzed with. + std::pair<DenseMap<BasicBlock*,Value*>::iterator, bool> + InsertRes = Visited.insert(std::make_pair(Pred, PredPtr)); + + if (!InsertRes.second) { + // If the predecessor was visited with PredPtr, then we already did + // the analysis and can ignore it. + if (InsertRes.first->second == PredPtr) + continue; + + // Otherwise, the block was previously analyzed with a different + // pointer. We can't represent the result of this case, so we just + // treat this as a phi translation failure. + goto PredTranslationFailure; + } + + // We may have added values to the cache list before this PHI + // translation. If so, we haven't done anything to ensure that the + // cache remains sorted. Sort it now (if needed) so that recursive + // invocations of getNonLocalPointerDepFromBB that could reuse the cache + // value will only see properly sorted cache arrays. + if (Cache && NumSortedEntries != Cache->size()) + std::sort(Cache->begin(), Cache->end()); + Cache = 0; + + // FIXME: it is entirely possible that PHI translating will end up with + // the same value. Consider PHI translating something like: + // X = phi [x, bb1], [y, bb2]. PHI translating for bb1 doesn't *need* + // to recurse here, pedantically speaking. + + // If we have a problem phi translating, fall through to the code below + // to handle the failure condition. + if (getNonLocalPointerDepFromBB(PredPtr, PointeeSize, isLoad, Pred, + Result, Visited)) + goto PredTranslationFailure; + } + + // Refresh the CacheInfo/Cache pointer so that it isn't invalidated. + CacheInfo = &NonLocalPointerDeps[CacheKey]; + Cache = &CacheInfo->second; + NumSortedEntries = Cache->size(); + + // Since we did phi translation, the "Cache" set won't contain all of the + // results for the query. This is ok (we can still use it to accelerate + // specific block queries) but we can't do the fastpath "return all + // results from the set" Clear out the indicator for this. + CacheInfo->first = BBSkipFirstBlockPair(); + SkipFirstBlock = false; + continue; + } + + // TODO: BITCAST, GEP. + + // cerr << "MEMDEP: Could not PHI translate: " << *Pointer; + // if (isa<BitCastInst>(PtrInst) || isa<GetElementPtrInst>(PtrInst)) + // cerr << "OP:\t\t\t\t" << *PtrInst->getOperand(0); + PredTranslationFailure: + + if (Cache == 0) { + // Refresh the CacheInfo/Cache pointer if it got invalidated. + CacheInfo = &NonLocalPointerDeps[CacheKey]; + Cache = &CacheInfo->second; + NumSortedEntries = Cache->size(); + } else if (NumSortedEntries != Cache->size()) { + std::sort(Cache->begin(), Cache->end()); + NumSortedEntries = Cache->size(); + } + + // Since we did phi translation, the "Cache" set won't contain all of the + // results for the query. This is ok (we can still use it to accelerate + // specific block queries) but we can't do the fastpath "return all + // results from the set" Clear out the indicator for this. 
+ CacheInfo->first = BBSkipFirstBlockPair(); + + // If *nothing* works, mark the pointer as being clobbered by the first + // instruction in this block. + // + // If this is the magic first block, return this as a clobber of the whole + // incoming value. Since we can't phi translate to one of the predecessors, + // we have to bail out. + if (SkipFirstBlock) + return true; + + for (NonLocalDepInfo::reverse_iterator I = Cache->rbegin(); ; ++I) { + assert(I != Cache->rend() && "Didn't find current block??"); + if (I->first != BB) + continue; + + assert(I->second.isNonLocal() && + "Should only be here with transparent block"); + I->second = MemDepResult::getClobber(BB->begin()); + ReverseNonLocalPtrDeps[BB->begin()].insert(CacheKey); + Result.push_back(*I); + break; + } + } + + // Okay, we're done now. If we added new values to the cache, re-sort it. + switch (Cache->size()-NumSortedEntries) { + case 0: + // done, no new entries. + break; + case 2: { + // Two new entries, insert the last one into place. + NonLocalDepEntry Val = Cache->back(); + Cache->pop_back(); + NonLocalDepInfo::iterator Entry = + std::upper_bound(Cache->begin(), Cache->end()-1, Val); + Cache->insert(Entry, Val); + // FALL THROUGH. + } + case 1: + // One new entry, Just insert the new value at the appropriate position. + if (Cache->size() != 1) { + NonLocalDepEntry Val = Cache->back(); + Cache->pop_back(); + NonLocalDepInfo::iterator Entry = + std::upper_bound(Cache->begin(), Cache->end(), Val); + Cache->insert(Entry, Val); + } + break; + default: + // Added many values, do a full scale sort. + std::sort(Cache->begin(), Cache->end()); + } + DEBUG(AssertSorted(*Cache)); + return false; +} + +/// RemoveCachedNonLocalPointerDependencies - If P exists in +/// CachedNonLocalPointerInfo, remove it. +void MemoryDependenceAnalysis:: +RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair P) { + CachedNonLocalPointerInfo::iterator It = + NonLocalPointerDeps.find(P); + if (It == NonLocalPointerDeps.end()) return; + + // Remove all of the entries in the BB->val map. This involves removing + // instructions from the reverse map. + NonLocalDepInfo &PInfo = It->second.second; + + for (unsigned i = 0, e = PInfo.size(); i != e; ++i) { + Instruction *Target = PInfo[i].second.getInst(); + if (Target == 0) continue; // Ignore non-local dep results. + assert(Target->getParent() == PInfo[i].first); + + // Eliminating the dirty entry from 'Cache', so update the reverse info. + RemoveFromReverseMap(ReverseNonLocalPtrDeps, Target, P); + } + + // Remove P from NonLocalPointerDeps (which deletes NonLocalDepInfo). + NonLocalPointerDeps.erase(It); +} + + +/// invalidateCachedPointerInfo - This method is used to invalidate cached +/// information about the specified pointer, because it may be too +/// conservative in memdep. This is an optional call that can be used when +/// the client detects an equivalence between the pointer and some other +/// value and replaces the other value with ptr. This can make Ptr available +/// in more places that cached info does not necessarily keep. +void MemoryDependenceAnalysis::invalidateCachedPointerInfo(Value *Ptr) { + // If Ptr isn't really a pointer, just ignore it. + if (!isa<PointerType>(Ptr->getType())) return; + // Flush store info for the pointer. + RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, false)); + // Flush load info for the pointer. 
+ RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, true)); +} + +/// removeInstruction - Remove an instruction from the dependence analysis, +/// updating the dependence of instructions that previously depended on it. +/// This method attempts to keep the cache coherent using the reverse map. +void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { + // Walk through the Non-local dependencies, removing this one as the value + // for any cached queries. + NonLocalDepMapType::iterator NLDI = NonLocalDeps.find(RemInst); + if (NLDI != NonLocalDeps.end()) { + NonLocalDepInfo &BlockMap = NLDI->second.first; + for (NonLocalDepInfo::iterator DI = BlockMap.begin(), DE = BlockMap.end(); + DI != DE; ++DI) + if (Instruction *Inst = DI->second.getInst()) + RemoveFromReverseMap(ReverseNonLocalDeps, Inst, RemInst); + NonLocalDeps.erase(NLDI); + } + + // If we have a cached local dependence query for this instruction, remove it. + // + LocalDepMapType::iterator LocalDepEntry = LocalDeps.find(RemInst); + if (LocalDepEntry != LocalDeps.end()) { + // Remove us from DepInst's reverse set now that the local dep info is gone. + if (Instruction *Inst = LocalDepEntry->second.getInst()) + RemoveFromReverseMap(ReverseLocalDeps, Inst, RemInst); + + // Remove this local dependency info. + LocalDeps.erase(LocalDepEntry); + } + + // If we have any cached pointer dependencies on this instruction, remove + // them. If the instruction has non-pointer type, then it can't be a pointer + // base. + + // Remove it from both the load info and the store info. The instruction + // can't be in either of these maps if it is non-pointer. + if (isa<PointerType>(RemInst->getType())) { + RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(RemInst, false)); + RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(RemInst, true)); + } + + // Loop over all of the things that depend on the instruction we're removing. + // + SmallVector<std::pair<Instruction*, Instruction*>, 8> ReverseDepsToAdd; + + // If we find RemInst as a clobber or Def in any of the maps for other values, + // we need to replace its entry with a dirty version of the instruction after + // it. If RemInst is a terminator, we use a null dirty value. + // + // Using a dirty version of the instruction after RemInst saves having to scan + // the entire block to get to this point. + MemDepResult NewDirtyVal; + if (!RemInst->isTerminator()) + NewDirtyVal = MemDepResult::getDirty(++BasicBlock::iterator(RemInst)); + + ReverseDepMapType::iterator ReverseDepIt = ReverseLocalDeps.find(RemInst); + if (ReverseDepIt != ReverseLocalDeps.end()) { + SmallPtrSet<Instruction*, 4> &ReverseDeps = ReverseDepIt->second; + // RemInst can't be the terminator if it has local stuff depending on it. + assert(!ReverseDeps.empty() && !isa<TerminatorInst>(RemInst) && + "Nothing can locally depend on a terminator"); + + for (SmallPtrSet<Instruction*, 4>::iterator I = ReverseDeps.begin(), + E = ReverseDeps.end(); I != E; ++I) { + Instruction *InstDependingOnRemInst = *I; + assert(InstDependingOnRemInst != RemInst && + "Already removed our local dep info"); + + LocalDeps[InstDependingOnRemInst] = NewDirtyVal; + + // Make sure to remember that new things depend on NewDepInst. 
+ assert(NewDirtyVal.getInst() && "There is no way something else can have " + "a local dep on this if it is a terminator!"); + ReverseDepsToAdd.push_back(std::make_pair(NewDirtyVal.getInst(), + InstDependingOnRemInst)); + } + + ReverseLocalDeps.erase(ReverseDepIt); + + // Add new reverse deps after scanning the set, to avoid invalidating the + // 'ReverseDeps' reference. + while (!ReverseDepsToAdd.empty()) { + ReverseLocalDeps[ReverseDepsToAdd.back().first] + .insert(ReverseDepsToAdd.back().second); + ReverseDepsToAdd.pop_back(); + } + } + + ReverseDepIt = ReverseNonLocalDeps.find(RemInst); + if (ReverseDepIt != ReverseNonLocalDeps.end()) { + SmallPtrSet<Instruction*, 4> &Set = ReverseDepIt->second; + for (SmallPtrSet<Instruction*, 4>::iterator I = Set.begin(), E = Set.end(); + I != E; ++I) { + assert(*I != RemInst && "Already removed NonLocalDep info for RemInst"); + + PerInstNLInfo &INLD = NonLocalDeps[*I]; + // The information is now dirty! + INLD.second = true; + + for (NonLocalDepInfo::iterator DI = INLD.first.begin(), + DE = INLD.first.end(); DI != DE; ++DI) { + if (DI->second.getInst() != RemInst) continue; + + // Convert to a dirty entry for the subsequent instruction. + DI->second = NewDirtyVal; + + if (Instruction *NextI = NewDirtyVal.getInst()) + ReverseDepsToAdd.push_back(std::make_pair(NextI, *I)); + } + } + + ReverseNonLocalDeps.erase(ReverseDepIt); + + // Add new reverse deps after scanning the set, to avoid invalidating 'Set' + while (!ReverseDepsToAdd.empty()) { + ReverseNonLocalDeps[ReverseDepsToAdd.back().first] + .insert(ReverseDepsToAdd.back().second); + ReverseDepsToAdd.pop_back(); + } + } + + // If the instruction is in ReverseNonLocalPtrDeps then it appears as a + // value in the NonLocalPointerDeps info. + ReverseNonLocalPtrDepTy::iterator ReversePtrDepIt = + ReverseNonLocalPtrDeps.find(RemInst); + if (ReversePtrDepIt != ReverseNonLocalPtrDeps.end()) { + SmallPtrSet<ValueIsLoadPair, 4> &Set = ReversePtrDepIt->second; + SmallVector<std::pair<Instruction*, ValueIsLoadPair>,8> ReversePtrDepsToAdd; + + for (SmallPtrSet<ValueIsLoadPair, 4>::iterator I = Set.begin(), + E = Set.end(); I != E; ++I) { + ValueIsLoadPair P = *I; + assert(P.getPointer() != RemInst && + "Already removed NonLocalPointerDeps info for RemInst"); + + NonLocalDepInfo &NLPDI = NonLocalPointerDeps[P].second; + + // The cache is not valid for any specific block anymore. + NonLocalPointerDeps[P].first = BBSkipFirstBlockPair(); + + // Update any entries for RemInst to use the instruction after it. + for (NonLocalDepInfo::iterator DI = NLPDI.begin(), DE = NLPDI.end(); + DI != DE; ++DI) { + if (DI->second.getInst() != RemInst) continue; + + // Convert to a dirty entry for the subsequent instruction. + DI->second = NewDirtyVal; + + if (Instruction *NewDirtyInst = NewDirtyVal.getInst()) + ReversePtrDepsToAdd.push_back(std::make_pair(NewDirtyInst, P)); + } + + // Re-sort the NonLocalDepInfo. Changing the dirty entry to its + // subsequent value may invalidate the sortedness. + std::sort(NLPDI.begin(), NLPDI.end()); + } + + ReverseNonLocalPtrDeps.erase(ReversePtrDepIt); + + while (!ReversePtrDepsToAdd.empty()) { + ReverseNonLocalPtrDeps[ReversePtrDepsToAdd.back().first] + .insert(ReversePtrDepsToAdd.back().second); + ReversePtrDepsToAdd.pop_back(); + } + } + + + assert(!NonLocalDeps.count(RemInst) && "RemInst got reinserted?"); + AA->deleteValue(RemInst); + DEBUG(verifyRemoved(RemInst)); +} +/// verifyRemoved - Verify that the specified instruction does not occur +/// in our internal data structures. 
+void MemoryDependenceAnalysis::verifyRemoved(Instruction *D) const { + for (LocalDepMapType::const_iterator I = LocalDeps.begin(), + E = LocalDeps.end(); I != E; ++I) { + assert(I->first != D && "Inst occurs in data structures"); + assert(I->second.getInst() != D && + "Inst occurs in data structures"); + } + + for (CachedNonLocalPointerInfo::const_iterator I =NonLocalPointerDeps.begin(), + E = NonLocalPointerDeps.end(); I != E; ++I) { + assert(I->first.getPointer() != D && "Inst occurs in NLPD map key"); + const NonLocalDepInfo &Val = I->second.second; + for (NonLocalDepInfo::const_iterator II = Val.begin(), E = Val.end(); + II != E; ++II) + assert(II->second.getInst() != D && "Inst occurs as NLPD value"); + } + + for (NonLocalDepMapType::const_iterator I = NonLocalDeps.begin(), + E = NonLocalDeps.end(); I != E; ++I) { + assert(I->first != D && "Inst occurs in data structures"); + const PerInstNLInfo &INLD = I->second; + for (NonLocalDepInfo::const_iterator II = INLD.first.begin(), + EE = INLD.first.end(); II != EE; ++II) + assert(II->second.getInst() != D && "Inst occurs in data structures"); + } + + for (ReverseDepMapType::const_iterator I = ReverseLocalDeps.begin(), + E = ReverseLocalDeps.end(); I != E; ++I) { + assert(I->first != D && "Inst occurs in data structures"); + for (SmallPtrSet<Instruction*, 4>::const_iterator II = I->second.begin(), + EE = I->second.end(); II != EE; ++II) + assert(*II != D && "Inst occurs in data structures"); + } + + for (ReverseDepMapType::const_iterator I = ReverseNonLocalDeps.begin(), + E = ReverseNonLocalDeps.end(); + I != E; ++I) { + assert(I->first != D && "Inst occurs in data structures"); + for (SmallPtrSet<Instruction*, 4>::const_iterator II = I->second.begin(), + EE = I->second.end(); II != EE; ++II) + assert(*II != D && "Inst occurs in data structures"); + } + + for (ReverseNonLocalPtrDepTy::const_iterator + I = ReverseNonLocalPtrDeps.begin(), + E = ReverseNonLocalPtrDeps.end(); I != E; ++I) { + assert(I->first != D && "Inst occurs in rev NLPD map"); + + for (SmallPtrSet<ValueIsLoadPair, 4>::const_iterator II = I->second.begin(), + E = I->second.end(); II != E; ++II) + assert(*II != ValueIsLoadPair(D, false) && + *II != ValueIsLoadPair(D, true) && + "Inst occurs in ReverseNonLocalPtrDeps map"); + } + +} diff --git a/lib/Analysis/PostDominators.cpp b/lib/Analysis/PostDominators.cpp new file mode 100644 index 0000000..4853c2a --- /dev/null +++ b/lib/Analysis/PostDominators.cpp @@ -0,0 +1,94 @@ +//===- PostDominators.cpp - Post-Dominator Calculation --------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the post-dominator construction algorithms. 
+// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "postdomtree" + +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Instructions.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Debug.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SetOperations.h" +#include "llvm/Analysis/DominatorInternals.h" +using namespace llvm; + +//===----------------------------------------------------------------------===// +// PostDominatorTree Implementation +//===----------------------------------------------------------------------===// + +char PostDominatorTree::ID = 0; +char PostDominanceFrontier::ID = 0; +static RegisterPass<PostDominatorTree> +F("postdomtree", "Post-Dominator Tree Construction", true, true); + +bool PostDominatorTree::runOnFunction(Function &F) { + DT->recalculate(F); + DEBUG(DT->dump()); + return false; +} + +PostDominatorTree::~PostDominatorTree() +{ + delete DT; +} + +FunctionPass* llvm::createPostDomTree() { + return new PostDominatorTree(); +} + +//===----------------------------------------------------------------------===// +// PostDominanceFrontier Implementation +//===----------------------------------------------------------------------===// + +static RegisterPass<PostDominanceFrontier> +H("postdomfrontier", "Post-Dominance Frontier Construction", true, true); + +const DominanceFrontier::DomSetType & +PostDominanceFrontier::calculate(const PostDominatorTree &DT, + const DomTreeNode *Node) { + // Loop over CFG successors to calculate DFlocal[Node] + BasicBlock *BB = Node->getBlock(); + DomSetType &S = Frontiers[BB]; // The new set to fill in... + if (getRoots().empty()) return S; + + if (BB) + for (pred_iterator SI = pred_begin(BB), SE = pred_end(BB); + SI != SE; ++SI) { + // Does Node immediately dominate this predecessor? + DomTreeNode *SINode = DT[*SI]; + if (SINode && SINode->getIDom() != Node) + S.insert(*SI); + } + + // At this point, S is DFlocal. Now we union in DFup's of our children... + // Loop through and visit the nodes that Node immediately dominates (Node's + // children in the IDomTree) + // + for (DomTreeNode::const_iterator + NI = Node->begin(), NE = Node->end(); NI != NE; ++NI) { + DomTreeNode *IDominee = *NI; + const DomSetType &ChildDF = calculate(DT, IDominee); + + DomSetType::const_iterator CDFI = ChildDF.begin(), CDFE = ChildDF.end(); + for (; CDFI != CDFE; ++CDFI) { + if (!DT.properlyDominates(Node, DT[*CDFI])) + S.insert(*CDFI); + } + } + + return S; +} + +FunctionPass* llvm::createPostDomFrontier() { + return new PostDominanceFrontier(); +} diff --git a/lib/Analysis/ProfileInfo.cpp b/lib/Analysis/ProfileInfo.cpp new file mode 100644 index 0000000..a0965b6 --- /dev/null +++ b/lib/Analysis/ProfileInfo.cpp @@ -0,0 +1,100 @@ +//===- ProfileInfo.cpp - Profile Info Interface ---------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the abstract ProfileInfo interface, and the default +// "no profile" implementation. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/ProfileInfo.h" +#include "llvm/Pass.h" +#include "llvm/Support/CFG.h" +#include "llvm/Support/Compiler.h" +#include <set> +using namespace llvm; + +// Register the ProfileInfo interface, providing a nice name to refer to. +static RegisterAnalysisGroup<ProfileInfo> Z("Profile Information"); +char ProfileInfo::ID = 0; + +ProfileInfo::~ProfileInfo() {} + +unsigned ProfileInfo::getExecutionCount(BasicBlock *BB) const { + pred_iterator PI = pred_begin(BB), PE = pred_end(BB); + + // Are there zero predecessors of this block? + if (PI == PE) { + // If this is the entry block, look for the Null -> Entry edge. + if (BB == &BB->getParent()->getEntryBlock()) + return getEdgeWeight(0, BB); + else + return 0; // Otherwise, this is a dead block. + } + + // Otherwise, if there are predecessors, the execution count of this block is + // the sum of the edge frequencies from the incoming edges. Note that if + // there are multiple edges from a predecessor to this block that we don't + // want to count its weight multiple times. For this reason, we keep track of + // the predecessors we've seen and only count them if we haven't run into them + // yet. + // + // We don't want to create an std::set unless we are dealing with a block that + // has a LARGE number of in-edges. Handle the common case of having only a + // few in-edges with special code. + // + BasicBlock *FirstPred = *PI; + unsigned Count = getEdgeWeight(FirstPred, BB); + ++PI; + if (PI == PE) return Count; // Quick exit for single predecessor blocks + + BasicBlock *SecondPred = *PI; + if (SecondPred != FirstPred) Count += getEdgeWeight(SecondPred, BB); + ++PI; + if (PI == PE) return Count; // Quick exit for two predecessor blocks + + BasicBlock *ThirdPred = *PI; + if (ThirdPred != FirstPred && ThirdPred != SecondPred) + Count += getEdgeWeight(ThirdPred, BB); + ++PI; + if (PI == PE) return Count; // Quick exit for three predecessor blocks + + std::set<BasicBlock*> ProcessedPreds; + ProcessedPreds.insert(FirstPred); + ProcessedPreds.insert(SecondPred); + ProcessedPreds.insert(ThirdPred); + for (; PI != PE; ++PI) + if (ProcessedPreds.insert(*PI).second) + Count += getEdgeWeight(*PI, BB); + return Count; +} + + + +//===----------------------------------------------------------------------===// +// NoProfile ProfileInfo implementation +// + +namespace { + struct VISIBILITY_HIDDEN NoProfileInfo + : public ImmutablePass, public ProfileInfo { + static char ID; // Class identification, replacement for typeinfo + NoProfileInfo() : ImmutablePass(&ID) {} + }; +} // End of anonymous namespace + +char NoProfileInfo::ID = 0; +// Register this pass... +static RegisterPass<NoProfileInfo> +X("no-profile", "No Profile Information", false, true); + +// Declare that we implement the ProfileInfo interface +static RegisterAnalysisGroup<ProfileInfo, true> Y(X); + +ImmutablePass *llvm::createNoProfileInfoPass() { return new NoProfileInfo(); } diff --git a/lib/Analysis/ProfileInfoLoader.cpp b/lib/Analysis/ProfileInfoLoader.cpp new file mode 100644 index 0000000..3a0a740 --- /dev/null +++ b/lib/Analysis/ProfileInfoLoader.cpp @@ -0,0 +1,277 @@ +//===- ProfileInfoLoad.cpp - Load profile information from disk -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// The ProfileInfoLoader class is used to load and represent profiling +// information read in from the dump file. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ProfileInfoLoader.h" +#include "llvm/Analysis/ProfileInfoTypes.h" +#include "llvm/Module.h" +#include "llvm/InstrTypes.h" +#include "llvm/Support/Streams.h" +#include <cstdio> +#include <cstdlib> +#include <map> +using namespace llvm; + +// ByteSwap - Byteswap 'Var' if 'Really' is true. +// +static inline unsigned ByteSwap(unsigned Var, bool Really) { + if (!Really) return Var; + return ((Var & (255<< 0)) << 24) | + ((Var & (255<< 8)) << 8) | + ((Var & (255<<16)) >> 8) | + ((Var & (255<<24)) >> 24); +} + +static void ReadProfilingBlock(const char *ToolName, FILE *F, + bool ShouldByteSwap, + std::vector<unsigned> &Data) { + // Read the number of entries... + unsigned NumEntries; + if (fread(&NumEntries, sizeof(unsigned), 1, F) != 1) { + cerr << ToolName << ": data packet truncated!\n"; + perror(0); + exit(1); + } + NumEntries = ByteSwap(NumEntries, ShouldByteSwap); + + // Read the counts... + std::vector<unsigned> TempSpace(NumEntries); + + // Read in the block of data... + if (fread(&TempSpace[0], sizeof(unsigned)*NumEntries, 1, F) != 1) { + cerr << ToolName << ": data packet truncated!\n"; + perror(0); + exit(1); + } + + // Make sure we have enough space... + if (Data.size() < NumEntries) + Data.resize(NumEntries); + + // Accumulate the data we just read into the data. + if (!ShouldByteSwap) { + for (unsigned i = 0; i != NumEntries; ++i) + Data[i] += TempSpace[i]; + } else { + for (unsigned i = 0; i != NumEntries; ++i) + Data[i] += ByteSwap(TempSpace[i], true); + } +} + +// ProfileInfoLoader ctor - Read the specified profiling data file, exiting the +// program if the file is invalid or broken. +// +ProfileInfoLoader::ProfileInfoLoader(const char *ToolName, + const std::string &Filename, + Module &TheModule) : M(TheModule) { + FILE *F = fopen(Filename.c_str(), "r"); + if (F == 0) { + cerr << ToolName << ": Error opening '" << Filename << "': "; + perror(0); + exit(1); + } + + // Keep reading packets until we run out of them. + unsigned PacketType; + while (fread(&PacketType, sizeof(unsigned), 1, F) == 1) { + // If the low eight bits of the packet are zero, we must be dealing with an + // endianness mismatch. Byteswap all words read from the profiling + // information. + bool ShouldByteSwap = (char)PacketType == 0; + PacketType = ByteSwap(PacketType, ShouldByteSwap); + + switch (PacketType) { + case ArgumentInfo: { + unsigned ArgLength; + if (fread(&ArgLength, sizeof(unsigned), 1, F) != 1) { + cerr << ToolName << ": arguments packet truncated!\n"; + perror(0); + exit(1); + } + ArgLength = ByteSwap(ArgLength, ShouldByteSwap); + + // Read in the arguments... 
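+      // Illustrative note, not from the original comment: the argument block
+      // is apparently stored rounded up to a 4-byte boundary, so a
+      // hypothetical ArgLength of 10 makes the fread below consume
+      // (10+3) & ~3 == 12 bytes; the ArgLength+4 slack in Chars keeps that
+      // read in bounds.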
+ std::vector<char> Chars(ArgLength+4); + + if (ArgLength) + if (fread(&Chars[0], (ArgLength+3) & ~3, 1, F) != 1) { + cerr << ToolName << ": arguments packet truncated!\n"; + perror(0); + exit(1); + } + CommandLines.push_back(std::string(&Chars[0], &Chars[ArgLength])); + break; + } + + case FunctionInfo: + ReadProfilingBlock(ToolName, F, ShouldByteSwap, FunctionCounts); + break; + + case BlockInfo: + ReadProfilingBlock(ToolName, F, ShouldByteSwap, BlockCounts); + break; + + case EdgeInfo: + ReadProfilingBlock(ToolName, F, ShouldByteSwap, EdgeCounts); + break; + + case BBTraceInfo: + ReadProfilingBlock(ToolName, F, ShouldByteSwap, BBTrace); + break; + + default: + cerr << ToolName << ": Unknown packet type #" << PacketType << "!\n"; + exit(1); + } + } + + fclose(F); +} + + +// getFunctionCounts - This method is used by consumers of function counting +// information. If we do not directly have function count information, we +// compute it from other, more refined, types of profile information. +// +void ProfileInfoLoader::getFunctionCounts(std::vector<std::pair<Function*, + unsigned> > &Counts) { + if (FunctionCounts.empty()) { + if (hasAccurateBlockCounts()) { + // Synthesize function frequency information from the number of times + // their entry blocks were executed. + std::vector<std::pair<BasicBlock*, unsigned> > BlockCounts; + getBlockCounts(BlockCounts); + + for (unsigned i = 0, e = BlockCounts.size(); i != e; ++i) + if (&BlockCounts[i].first->getParent()->getEntryBlock() == + BlockCounts[i].first) + Counts.push_back(std::make_pair(BlockCounts[i].first->getParent(), + BlockCounts[i].second)); + } else { + cerr << "Function counts are not available!\n"; + } + return; + } + + unsigned Counter = 0; + for (Module::iterator I = M.begin(), E = M.end(); + I != E && Counter != FunctionCounts.size(); ++I) + if (!I->isDeclaration()) + Counts.push_back(std::make_pair(I, FunctionCounts[Counter++])); +} + +// getBlockCounts - This method is used by consumers of block counting +// information. If we do not directly have block count information, we +// compute it from other, more refined, types of profile information. +// +void ProfileInfoLoader::getBlockCounts(std::vector<std::pair<BasicBlock*, + unsigned> > &Counts) { + if (BlockCounts.empty()) { + if (hasAccurateEdgeCounts()) { + // Synthesize block count information from edge frequency information. + // The block execution frequency is equal to the sum of the execution + // frequency of all outgoing edges from a block. + // + // If a block has no successors, this will not be correct, so we have to + // special case it. :( + std::vector<std::pair<Edge, unsigned> > EdgeCounts; + getEdgeCounts(EdgeCounts); + + std::map<BasicBlock*, unsigned> InEdgeFreqs; + + BasicBlock *LastBlock = 0; + TerminatorInst *TI = 0; + for (unsigned i = 0, e = EdgeCounts.size(); i != e; ++i) { + if (EdgeCounts[i].first.first != LastBlock) { + LastBlock = EdgeCounts[i].first.first; + TI = LastBlock->getTerminator(); + Counts.push_back(std::make_pair(LastBlock, 0)); + } + Counts.back().second += EdgeCounts[i].second; + unsigned SuccNum = EdgeCounts[i].first.second; + if (SuccNum >= TI->getNumSuccessors()) { + static bool Warned = false; + if (!Warned) { + cerr << "WARNING: profile info doesn't seem to match" + << " the program!\n"; + Warned = true; + } + } else { + // If this successor has no successors of its own, we will never + // compute an execution count for that block. Remember the incoming + // edge frequencies to add later. 
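+          // (For example, a block ending in a 'ret' has no out-edges, so its
+          // count can only be recovered from the in-edge sums accumulated in
+          // InEdgeFreqs and folded in below.)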
+ BasicBlock *Succ = TI->getSuccessor(SuccNum); + if (Succ->getTerminator()->getNumSuccessors() == 0) + InEdgeFreqs[Succ] += EdgeCounts[i].second; + } + } + + // Now we have to accumulate information for those blocks without + // successors into our table. + for (std::map<BasicBlock*, unsigned>::iterator I = InEdgeFreqs.begin(), + E = InEdgeFreqs.end(); I != E; ++I) { + unsigned i = 0; + for (; i != Counts.size() && Counts[i].first != I->first; ++i) + /*empty*/; + if (i == Counts.size()) Counts.push_back(std::make_pair(I->first, 0)); + Counts[i].second += I->second; + } + + } else { + cerr << "Block counts are not available!\n"; + } + return; + } + + unsigned Counter = 0; + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) { + Counts.push_back(std::make_pair(BB, BlockCounts[Counter++])); + if (Counter == BlockCounts.size()) + return; + } +} + +// getEdgeCounts - This method is used by consumers of edge counting +// information. If we do not directly have edge count information, we compute +// it from other, more refined, types of profile information. +// +void ProfileInfoLoader::getEdgeCounts(std::vector<std::pair<Edge, + unsigned> > &Counts) { + if (EdgeCounts.empty()) { + cerr << "Edge counts not available, and no synthesis " + << "is implemented yet!\n"; + return; + } + + unsigned Counter = 0; + for (Module::iterator F = M.begin(), E = M.end(); F != E; ++F) + for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) + for (unsigned i = 0, e = BB->getTerminator()->getNumSuccessors(); + i != e; ++i) { + Counts.push_back(std::make_pair(Edge(BB, i), EdgeCounts[Counter++])); + if (Counter == EdgeCounts.size()) + return; + } +} + +// getBBTrace - This method is used by consumers of basic-block trace +// information. +// +void ProfileInfoLoader::getBBTrace(std::vector<BasicBlock *> &Trace) { + if (BBTrace.empty ()) { + cerr << "Basic block trace is not available!\n"; + return; + } + cerr << "Basic block trace loading is not implemented yet!\n"; +} diff --git a/lib/Analysis/ProfileInfoLoaderPass.cpp b/lib/Analysis/ProfileInfoLoaderPass.cpp new file mode 100644 index 0000000..0a8a87b --- /dev/null +++ b/lib/Analysis/ProfileInfoLoaderPass.cpp @@ -0,0 +1,92 @@ +//===- ProfileInfoLoaderPass.cpp - LLVM Pass to load profile info ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a concrete implementation of profiling information that +// loads the information from a profile dump file. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/BasicBlock.h" +#include "llvm/InstrTypes.h" +#include "llvm/Pass.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/ProfileInfo.h" +#include "llvm/Analysis/ProfileInfoLoader.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Streams.h" +using namespace llvm; + +static cl::opt<std::string> +ProfileInfoFilename("profile-info-file", cl::init("llvmprof.out"), + cl::value_desc("filename"), + cl::desc("Profile file loaded by -profile-loader")); + +namespace { + class VISIBILITY_HIDDEN LoaderPass : public ModulePass, public ProfileInfo { + std::string Filename; + public: + static char ID; // Class identification, replacement for typeinfo + explicit LoaderPass(const std::string &filename = "") + : ModulePass(&ID), Filename(filename) { + if (filename.empty()) Filename = ProfileInfoFilename; + } + + virtual void getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + } + + virtual const char *getPassName() const { + return "Profiling information loader"; + } + + /// run - Load the profile information from the specified file. + virtual bool runOnModule(Module &M); + }; +} // End of anonymous namespace + +char LoaderPass::ID = 0; +static RegisterPass<LoaderPass> +X("profile-loader", "Load profile information from llvmprof.out", false, true); + +static RegisterAnalysisGroup<ProfileInfo> Y(X); + +ModulePass *llvm::createProfileLoaderPass() { return new LoaderPass(); } + +/// createProfileLoaderPass - This function returns a Pass that loads the +/// profiling information for the module from the specified filename, making it +/// available to the optimizers. +Pass *llvm::createProfileLoaderPass(const std::string &Filename) { + return new LoaderPass(Filename); +} + +bool LoaderPass::runOnModule(Module &M) { + ProfileInfoLoader PIL("profile-loader", Filename, M); + EdgeCounts.clear(); + bool PrintedWarning = false; + + std::vector<std::pair<ProfileInfoLoader::Edge, unsigned> > ECs; + PIL.getEdgeCounts(ECs); + for (unsigned i = 0, e = ECs.size(); i != e; ++i) { + BasicBlock *BB = ECs[i].first.first; + unsigned SuccNum = ECs[i].first.second; + TerminatorInst *TI = BB->getTerminator(); + if (SuccNum >= TI->getNumSuccessors()) { + if (!PrintedWarning) { + cerr << "WARNING: profile information is inconsistent with " + << "the current program!\n"; + PrintedWarning = true; + } + } else { + EdgeCounts[std::make_pair(BB, TI->getSuccessor(SuccNum))]+= ECs[i].second; + } + } + + return false; +} diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp new file mode 100644 index 0000000..f7f1849 --- /dev/null +++ b/lib/Analysis/ScalarEvolution.cpp @@ -0,0 +1,3824 @@ +//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation of the scalar evolution analysis +// engine, which is used primarily to analyze expressions involving induction +// variables in loops. +// +// There are several aspects to this library. First is the representation of +// scalar expressions, which are represented as subclasses of the SCEV class. 
+// These classes are used to represent certain types of subexpressions that we +// can handle. These classes are reference counted, managed by the SCEVHandle +// class. We only create one SCEV of a particular shape, so pointer-comparisons +// for equality are legal. +// +// One important aspect of the SCEV objects is that they are never cyclic, even +// if there is a cycle in the dataflow for an expression (ie, a PHI node). If +// the PHI node is one of the idioms that we can represent (e.g., a polynomial +// recurrence) then we represent it directly as a recurrence node, otherwise we +// represent it as a SCEVUnknown node. +// +// In addition to being able to represent expressions of various types, we also +// have folders that are used to build the *canonical* representation for a +// particular expression. These folders are capable of using a variety of +// rewrite rules to simplify the expressions. +// +// Once the folders are defined, we can implement the more interesting +// higher-level code, such as the code that recognizes PHI nodes of various +// types, computes the execution count of a loop, etc. +// +// TODO: We should use these routines and value representations to implement +// dependence analysis! +// +//===----------------------------------------------------------------------===// +// +// There are several good references for the techniques used in this analysis. +// +// Chains of recurrences -- a method to expedite the evaluation +// of closed-form functions +// Olaf Bachmann, Paul S. Wang, Eugene V. Zima +// +// On computational properties of chains of recurrences +// Eugene V. Zima +// +// Symbolic Evaluation of Chains of Recurrences for Loop Optimization +// Robert A. van Engelen +// +// Efficient Symbolic Analysis for Optimizing Compilers +// Robert A. 
van Engelen +// +// Using the chains of recurrences algebra for data dependence testing and +// induction variable substitution +// MS Thesis, Johnie Birch +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "scalar-evolution" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Constants.h" +#include "llvm/DerivedTypes.h" +#include "llvm/GlobalVariable.h" +#include "llvm/Instructions.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/Dominators.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/ConstantRange.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Support/InstIterator.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include <ostream> +#include <algorithm> +using namespace llvm; + +STATISTIC(NumArrayLenItCounts, + "Number of trip counts computed with array length"); +STATISTIC(NumTripCountsComputed, + "Number of loops with predictable loop counts"); +STATISTIC(NumTripCountsNotComputed, + "Number of loops without predictable loop counts"); +STATISTIC(NumBruteForceTripCountsComputed, + "Number of loops with trip counts computed by force"); + +static cl::opt<unsigned> +MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, + cl::desc("Maximum number of iterations SCEV will " + "symbolically execute a constant derived loop"), + cl::init(100)); + +static RegisterPass<ScalarEvolution> +R("scalar-evolution", "Scalar Evolution Analysis", false, true); +char ScalarEvolution::ID = 0; + +//===----------------------------------------------------------------------===// +// SCEV class definitions +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Implementation of the SCEV class. 
+// +SCEV::~SCEV() {} +void SCEV::dump() const { + print(errs()); + errs() << '\n'; +} + +void SCEV::print(std::ostream &o) const { + raw_os_ostream OS(o); + print(OS); +} + +bool SCEV::isZero() const { + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this)) + return SC->getValue()->isZero(); + return false; +} + +bool SCEV::isOne() const { + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this)) + return SC->getValue()->isOne(); + return false; +} + +SCEVCouldNotCompute::SCEVCouldNotCompute() : SCEV(scCouldNotCompute) {} +SCEVCouldNotCompute::~SCEVCouldNotCompute() {} + +bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const { + assert(0 && "Attempt to use a SCEVCouldNotCompute object!"); + return false; +} + +const Type *SCEVCouldNotCompute::getType() const { + assert(0 && "Attempt to use a SCEVCouldNotCompute object!"); + return 0; +} + +bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const { + assert(0 && "Attempt to use a SCEVCouldNotCompute object!"); + return false; +} + +SCEVHandle SCEVCouldNotCompute:: +replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, + const SCEVHandle &Conc, + ScalarEvolution &SE) const { + return this; +} + +void SCEVCouldNotCompute::print(raw_ostream &OS) const { + OS << "***COULDNOTCOMPUTE***"; +} + +bool SCEVCouldNotCompute::classof(const SCEV *S) { + return S->getSCEVType() == scCouldNotCompute; +} + + +// SCEVConstants - Only allow the creation of one SCEVConstant for any +// particular value. Don't use a SCEVHandle here, or else the object will +// never be deleted! +static ManagedStatic<std::map<ConstantInt*, SCEVConstant*> > SCEVConstants; + + +SCEVConstant::~SCEVConstant() { + SCEVConstants->erase(V); +} + +SCEVHandle ScalarEvolution::getConstant(ConstantInt *V) { + SCEVConstant *&R = (*SCEVConstants)[V]; + if (R == 0) R = new SCEVConstant(V); + return R; +} + +SCEVHandle ScalarEvolution::getConstant(const APInt& Val) { + return getConstant(ConstantInt::get(Val)); +} + +const Type *SCEVConstant::getType() const { return V->getType(); } + +void SCEVConstant::print(raw_ostream &OS) const { + WriteAsOperand(OS, V, false); +} + +SCEVCastExpr::SCEVCastExpr(unsigned SCEVTy, + const SCEVHandle &op, const Type *ty) + : SCEV(SCEVTy), Op(op), Ty(ty) {} + +SCEVCastExpr::~SCEVCastExpr() {} + +bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { + return Op->dominates(BB, DT); +} + +// SCEVTruncates - Only allow the creation of one SCEVTruncateExpr for any +// particular input. Don't use a SCEVHandle here, or else the object will +// never be deleted! +static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>, + SCEVTruncateExpr*> > SCEVTruncates; + +SCEVTruncateExpr::SCEVTruncateExpr(const SCEVHandle &op, const Type *ty) + : SCEVCastExpr(scTruncate, op, ty) { + assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) && + (Ty->isInteger() || isa<PointerType>(Ty)) && + "Cannot truncate non-integer value!"); +} + +SCEVTruncateExpr::~SCEVTruncateExpr() { + SCEVTruncates->erase(std::make_pair(Op, Ty)); +} + +void SCEVTruncateExpr::print(raw_ostream &OS) const { + OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; +} + +// SCEVZeroExtends - Only allow the creation of one SCEVZeroExtendExpr for any +// particular input. Don't use a SCEVHandle here, or else the object will never +// be deleted! 
+static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>, + SCEVZeroExtendExpr*> > SCEVZeroExtends; + +SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEVHandle &op, const Type *ty) + : SCEVCastExpr(scZeroExtend, op, ty) { + assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) && + (Ty->isInteger() || isa<PointerType>(Ty)) && + "Cannot zero extend non-integer value!"); +} + +SCEVZeroExtendExpr::~SCEVZeroExtendExpr() { + SCEVZeroExtends->erase(std::make_pair(Op, Ty)); +} + +void SCEVZeroExtendExpr::print(raw_ostream &OS) const { + OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; +} + +// SCEVSignExtends - Only allow the creation of one SCEVSignExtendExpr for any +// particular input. Don't use a SCEVHandle here, or else the object will never +// be deleted! +static ManagedStatic<std::map<std::pair<const SCEV*, const Type*>, + SCEVSignExtendExpr*> > SCEVSignExtends; + +SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEVHandle &op, const Type *ty) + : SCEVCastExpr(scSignExtend, op, ty) { + assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) && + (Ty->isInteger() || isa<PointerType>(Ty)) && + "Cannot sign extend non-integer value!"); +} + +SCEVSignExtendExpr::~SCEVSignExtendExpr() { + SCEVSignExtends->erase(std::make_pair(Op, Ty)); +} + +void SCEVSignExtendExpr::print(raw_ostream &OS) const { + OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")"; +} + +// SCEVCommExprs - Only allow the creation of one SCEVCommutativeExpr for any +// particular input. Don't use a SCEVHandle here, or else the object will never +// be deleted! +static ManagedStatic<std::map<std::pair<unsigned, std::vector<const SCEV*> >, + SCEVCommutativeExpr*> > SCEVCommExprs; + +SCEVCommutativeExpr::~SCEVCommutativeExpr() { + std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end()); + SCEVCommExprs->erase(std::make_pair(getSCEVType(), SCEVOps)); +} + +void SCEVCommutativeExpr::print(raw_ostream &OS) const { + assert(Operands.size() > 1 && "This plus expr shouldn't exist!"); + const char *OpStr = getOperationStr(); + OS << "(" << *Operands[0]; + for (unsigned i = 1, e = Operands.size(); i != e; ++i) + OS << OpStr << *Operands[i]; + OS << ")"; +} + +SCEVHandle SCEVCommutativeExpr:: +replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, + const SCEVHandle &Conc, + ScalarEvolution &SE) const { + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + SCEVHandle H = + getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); + if (H != getOperand(i)) { + std::vector<SCEVHandle> NewOps; + NewOps.reserve(getNumOperands()); + for (unsigned j = 0; j != i; ++j) + NewOps.push_back(getOperand(j)); + NewOps.push_back(H); + for (++i; i != e; ++i) + NewOps.push_back(getOperand(i)-> + replaceSymbolicValuesWithConcrete(Sym, Conc, SE)); + + if (isa<SCEVAddExpr>(this)) + return SE.getAddExpr(NewOps); + else if (isa<SCEVMulExpr>(this)) + return SE.getMulExpr(NewOps); + else if (isa<SCEVSMaxExpr>(this)) + return SE.getSMaxExpr(NewOps); + else if (isa<SCEVUMaxExpr>(this)) + return SE.getUMaxExpr(NewOps); + else + assert(0 && "Unknown commutative expr!"); + } + } + return this; +} + +bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + if (!getOperand(i)->dominates(BB, DT)) + return false; + } + return true; +} + + +// SCEVUDivs - Only allow the creation of one SCEVUDivExpr for any particular +// input. 
Don't use a SCEVHandle here, or else the object will never be +// deleted! +static ManagedStatic<std::map<std::pair<const SCEV*, const SCEV*>, + SCEVUDivExpr*> > SCEVUDivs; + +SCEVUDivExpr::~SCEVUDivExpr() { + SCEVUDivs->erase(std::make_pair(LHS, RHS)); +} + +bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const { + return LHS->dominates(BB, DT) && RHS->dominates(BB, DT); +} + +void SCEVUDivExpr::print(raw_ostream &OS) const { + OS << "(" << *LHS << " /u " << *RHS << ")"; +} + +const Type *SCEVUDivExpr::getType() const { + // In most cases the types of LHS and RHS will be the same, but in some + // crazy cases one or the other may be a pointer. ScalarEvolution doesn't + // depend on the type for correctness, but handling types carefully can + // avoid extra casts in the SCEVExpander. The LHS is more likely to be + // a pointer type than the RHS, so use the RHS' type here. + return RHS->getType(); +} + +// SCEVAddRecExprs - Only allow the creation of one SCEVAddRecExpr for any +// particular input. Don't use a SCEVHandle here, or else the object will never +// be deleted! +static ManagedStatic<std::map<std::pair<const Loop *, + std::vector<const SCEV*> >, + SCEVAddRecExpr*> > SCEVAddRecExprs; + +SCEVAddRecExpr::~SCEVAddRecExpr() { + std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end()); + SCEVAddRecExprs->erase(std::make_pair(L, SCEVOps)); +} + +SCEVHandle SCEVAddRecExpr:: +replaceSymbolicValuesWithConcrete(const SCEVHandle &Sym, + const SCEVHandle &Conc, + ScalarEvolution &SE) const { + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + SCEVHandle H = + getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE); + if (H != getOperand(i)) { + std::vector<SCEVHandle> NewOps; + NewOps.reserve(getNumOperands()); + for (unsigned j = 0; j != i; ++j) + NewOps.push_back(getOperand(j)); + NewOps.push_back(H); + for (++i; i != e; ++i) + NewOps.push_back(getOperand(i)-> + replaceSymbolicValuesWithConcrete(Sym, Conc, SE)); + + return SE.getAddRecExpr(NewOps, L); + } + } + return this; +} + + +bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const { + // This recurrence is invariant w.r.t to QueryLoop iff QueryLoop doesn't + // contain L and if the start is invariant. + // Add recurrences are never invariant in the function-body (null loop). + return QueryLoop && + !QueryLoop->contains(L->getHeader()) && + getOperand(0)->isLoopInvariant(QueryLoop); +} + + +void SCEVAddRecExpr::print(raw_ostream &OS) const { + OS << "{" << *Operands[0]; + for (unsigned i = 1, e = Operands.size(); i != e; ++i) + OS << ",+," << *Operands[i]; + OS << "}<" << L->getHeader()->getName() + ">"; +} + +// SCEVUnknowns - Only allow the creation of one SCEVUnknown for any particular +// value. Don't use a SCEVHandle here, or else the object will never be +// deleted! +static ManagedStatic<std::map<Value*, SCEVUnknown*> > SCEVUnknowns; + +SCEVUnknown::~SCEVUnknown() { SCEVUnknowns->erase(V); } + +bool SCEVUnknown::isLoopInvariant(const Loop *L) const { + // All non-instruction values are loop invariant. All instructions are loop + // invariant if they are not contained in the specified loop. + // Instructions are never considered invariant in the function body + // (null loop) because they are defined within the "loop". 
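+  // (Illustration: a global variable or function argument wrapped in a
+  // SCEVUnknown is invariant in any loop, while a load defined inside L is
+  // not; instructions are also never invariant when L is the null
+  // function-body "loop".)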
+ if (Instruction *I = dyn_cast<Instruction>(V)) + return L && !L->contains(I->getParent()); + return true; +} + +bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const { + if (Instruction *I = dyn_cast<Instruction>(getValue())) + return DT->dominates(I->getParent(), BB); + return true; +} + +const Type *SCEVUnknown::getType() const { + return V->getType(); +} + +void SCEVUnknown::print(raw_ostream &OS) const { + WriteAsOperand(OS, V, false); +} + +//===----------------------------------------------------------------------===// +// SCEV Utilities +//===----------------------------------------------------------------------===// + +namespace { + /// SCEVComplexityCompare - Return true if the complexity of the LHS is less + /// than the complexity of the RHS. This comparator is used to canonicalize + /// expressions. + class VISIBILITY_HIDDEN SCEVComplexityCompare { + LoopInfo *LI; + public: + explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {} + + bool operator()(const SCEV *LHS, const SCEV *RHS) const { + // Primarily, sort the SCEVs by their getSCEVType(). + if (LHS->getSCEVType() != RHS->getSCEVType()) + return LHS->getSCEVType() < RHS->getSCEVType(); + + // Aside from the getSCEVType() ordering, the particular ordering + // isn't very important except that it's beneficial to be consistent, + // so that (a + b) and (b + a) don't end up as different expressions. + + // Sort SCEVUnknown values with some loose heuristics. TODO: This is + // not as complete as it could be. + if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS)) { + const SCEVUnknown *RU = cast<SCEVUnknown>(RHS); + + // Order pointer values after integer values. This helps SCEVExpander + // form GEPs. + if (isa<PointerType>(LU->getType()) && !isa<PointerType>(RU->getType())) + return false; + if (isa<PointerType>(RU->getType()) && !isa<PointerType>(LU->getType())) + return true; + + // Compare getValueID values. + if (LU->getValue()->getValueID() != RU->getValue()->getValueID()) + return LU->getValue()->getValueID() < RU->getValue()->getValueID(); + + // Sort arguments by their position. + if (const Argument *LA = dyn_cast<Argument>(LU->getValue())) { + const Argument *RA = cast<Argument>(RU->getValue()); + return LA->getArgNo() < RA->getArgNo(); + } + + // For instructions, compare their loop depth, and their opcode. + // This is pretty loose. + if (Instruction *LV = dyn_cast<Instruction>(LU->getValue())) { + Instruction *RV = cast<Instruction>(RU->getValue()); + + // Compare loop depths. + if (LI->getLoopDepth(LV->getParent()) != + LI->getLoopDepth(RV->getParent())) + return LI->getLoopDepth(LV->getParent()) < + LI->getLoopDepth(RV->getParent()); + + // Compare opcodes. + if (LV->getOpcode() != RV->getOpcode()) + return LV->getOpcode() < RV->getOpcode(); + + // Compare the number of operands. + if (LV->getNumOperands() != RV->getNumOperands()) + return LV->getNumOperands() < RV->getNumOperands(); + } + + return false; + } + + // Constant sorting doesn't matter since they'll be folded. + if (isa<SCEVConstant>(LHS)) + return false; + + // Lexicographically compare n-ary expressions. 
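+      // (For instance, ordering a hypothetical (a + b) against (a + b + c):
+      // the shared operands compare equal, so the operand-count test below
+      // makes the shorter expression sort first.)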
+      if (const SCEVNAryExpr *LC = dyn_cast<SCEVNAryExpr>(LHS)) {
+        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
+        for (unsigned i = 0, e = LC->getNumOperands(); i != e; ++i) {
+          if (i >= RC->getNumOperands())
+            return false;
+          if (operator()(LC->getOperand(i), RC->getOperand(i)))
+            return true;
+          if (operator()(RC->getOperand(i), LC->getOperand(i)))
+            return false;
+        }
+        return LC->getNumOperands() < RC->getNumOperands();
+      }
+
+      // Lexicographically compare udiv expressions.
+      if (const SCEVUDivExpr *LC = dyn_cast<SCEVUDivExpr>(LHS)) {
+        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
+        if (operator()(LC->getLHS(), RC->getLHS()))
+          return true;
+        if (operator()(RC->getLHS(), LC->getLHS()))
+          return false;
+        if (operator()(LC->getRHS(), RC->getRHS()))
+          return true;
+        if (operator()(RC->getRHS(), LC->getRHS()))
+          return false;
+        return false;
+      }
+
+      // Compare cast expressions by operand.
+      if (const SCEVCastExpr *LC = dyn_cast<SCEVCastExpr>(LHS)) {
+        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
+        return operator()(LC->getOperand(), RC->getOperand());
+      }
+
+      assert(0 && "Unknown SCEV kind!");
+      return false;
+    }
+  };
+}
+
+/// GroupByComplexity - Given a list of SCEV objects, order them by their
+/// complexity, and group objects of the same complexity together by value.
+/// When this routine is finished, we know that any duplicates in the vector are
+/// consecutive and that complexity is monotonically increasing.
+///
+/// Note that we take special precautions to ensure that we get deterministic
+/// results from this routine.  In other words, we don't want the results of
+/// this to depend on where the addresses of various SCEV objects happened to
+/// land in memory.
+///
+static void GroupByComplexity(std::vector<SCEVHandle> &Ops,
+                              LoopInfo *LI) {
+  if (Ops.size() < 2) return;  // Noop
+  if (Ops.size() == 2) {
+    // This is the common case, which also happens to be trivially simple.
+    // Special case it.
+    if (SCEVComplexityCompare(LI)(Ops[1], Ops[0]))
+      std::swap(Ops[0], Ops[1]);
+    return;
+  }
+
+  // Do the rough sort by complexity.
+  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));
+
+  // Now that we are sorted by complexity, group elements of the same
+  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
+  // be extremely short in practice.  Note that we take this approach because we
+  // do not want to depend on the addresses of the objects we are grouping.
+  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
+    const SCEV *S = Ops[i];
+    unsigned Complexity = S->getSCEVType();
+
+    // If there are any objects of the same complexity and same value as this
+    // one, group them.
+    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
+      if (Ops[j] == S) { // Found a duplicate.
+        // Move it to immediately after i'th element.
+        std::swap(Ops[i+1], Ops[j]);
+        ++i;   // no need to rescan it.
+        if (i == e-2) return;  // Done!
+      }
+    }
+  }
+}
+
+
+
+//===----------------------------------------------------------------------===//
+//                      Simple SCEV method implementations
+//===----------------------------------------------------------------------===//
+
+/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
+/// Assume K > 0.
+static SCEVHandle BinomialCoefficient(SCEVHandle It, unsigned K,
+                                      ScalarEvolution &SE,
+                                      const Type* ResultTy) {
+  // Handle the simplest case efficiently.
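+  // (BC(It, 1) is just It itself, so the K == 1 case below only needs a cast
+  // to the requested result width.)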
+  if (K == 1)
+    return SE.getTruncateOrZeroExtend(It, ResultTy);
+
+  // We are using the following formula for BC(It, K):
+  //
+  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
+  //
+  // Suppose W is the bitwidth of the return value.  We must be prepared for
+  // overflow.  Hence, we must ensure that the result of our computation is
+  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
+  // safe in modular arithmetic.
+  //
+  // However, this code doesn't use exactly that formula; the formula it uses
+  // is something like the following, where T is the number of factors of 2 in
+  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
+  // exponentiation:
+  //
+  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
+  //
+  // This formula is trivially equivalent to the previous formula.  However,
+  // this formula can be implemented much more efficiently.  The trick is that
+  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
+  // arithmetic.  To do exact division in modular arithmetic, all we have
+  // to do is multiply by the inverse.  Therefore, this step can be done at
+  // width W.
+  //
+  // The next issue is how to safely do the division by 2^T.  The way this
+  // is done is by doing the multiplication step at a width of at least W + T
+  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
+  // when we perform the division by 2^T (which is equivalent to a right shift
+  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
+  // truncated out after the division by 2^T.
+  //
+  // In comparison to just directly using the first formula, this technique
+  // is much more efficient; using the first formula requires W * K bits,
+  // but this formula requires less than W + K bits. Also, the first formula
+  // requires a division step, whereas this formula only requires multiplies
+  // and shifts.
+  //
+  // It doesn't matter whether the subtraction step is done in the calculation
+  // width or the input iteration count's width; if the subtraction overflows,
+  // the result must be zero anyway.  We prefer here to do it in the width of
+  // the induction variable because it helps a lot for certain cases; CodeGen
+  // isn't smart enough to ignore the overflow, which leads to much less
+  // efficient code if the width of the subtraction is wider than the native
+  // register width.
+  //
+  // (It's possible to not widen at all by pulling out factors of 2 before
+  // the multiplication; for example, K=2 can be calculated as
+  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
+  // extra arithmetic, so it's not an obvious win, and it gets
+  // much more complicated for K > 3.)
+
+  // Protection from insane SCEVs; this bound is conservative,
+  // but it probably doesn't matter.
+  if (K > 1000)
+    return SE.getCouldNotCompute();
+
+  unsigned W = SE.getTypeSizeInBits(ResultTy);
+
+  // Calculate K! / 2^T and T; we divide out the factors of two before
+  // multiplying for calculating K! / 2^T to avoid overflow.
+  // Other overflow doesn't matter because we only care about the bottom
+  // W bits of the result.
+  APInt OddFactorial(W, 1);
+  unsigned T = 1;
+  for (unsigned i = 3; i <= K; ++i) {
+    APInt Mult(W, i);
+    unsigned TwoFactors = Mult.countTrailingZeros();
+    T += TwoFactors;
+    Mult = Mult.lshr(TwoFactors);
+    OddFactorial *= Mult;
+  }
+
+  // We need at least W + T bits for the multiplication step
+  unsigned CalculationBits = W + T;
+
+  // Calculate 2^T, at width T+W.
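+  // (Worked example with an illustrative K = 4: K! = 24 = 2^3 * 3, so the
+  // loop above leaves T = 3 and OddFactorial = 3; DivFactor below is then
+  // 2^3 == 8, computed at width W + 3.)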
+ APInt DivFactor = APInt(CalculationBits, 1).shl(T); + + // Calculate the multiplicative inverse of K! / 2^T; + // this multiplication factor will perform the exact division by + // K! / 2^T. + APInt Mod = APInt::getSignedMinValue(W+1); + APInt MultiplyFactor = OddFactorial.zext(W+1); + MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod); + MultiplyFactor = MultiplyFactor.trunc(W); + + // Calculate the product, at width T+W + const IntegerType *CalculationTy = IntegerType::get(CalculationBits); + SCEVHandle Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); + for (unsigned i = 1; i != K; ++i) { + SCEVHandle S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType())); + Dividend = SE.getMulExpr(Dividend, + SE.getTruncateOrZeroExtend(S, CalculationTy)); + } + + // Divide by 2^T + SCEVHandle DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); + + // Truncate the result, and divide by K! / 2^T. + + return SE.getMulExpr(SE.getConstant(MultiplyFactor), + SE.getTruncateOrZeroExtend(DivResult, ResultTy)); +} + +/// evaluateAtIteration - Return the value of this chain of recurrences at +/// the specified iteration number. We can evaluate this recurrence by +/// multiplying each element in the chain by the binomial coefficient +/// corresponding to it. In other words, we can evaluate {A,+,B,+,C,+,D} as: +/// +/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) +/// +/// where BC(It, k) stands for binomial coefficient. +/// +SCEVHandle SCEVAddRecExpr::evaluateAtIteration(SCEVHandle It, + ScalarEvolution &SE) const { + SCEVHandle Result = getStart(); + for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { + // The computation is correct in the face of overflow provided that the + // multiplication is performed _after_ the evaluation of the binomial + // coefficient. + SCEVHandle Coeff = BinomialCoefficient(It, i, SE, getType()); + if (isa<SCEVCouldNotCompute>(Coeff)) + return Coeff; + + Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff)); + } + return Result; +} + +//===----------------------------------------------------------------------===// +// SCEV Expression folder implementations +//===----------------------------------------------------------------------===// + +SCEVHandle ScalarEvolution::getTruncateExpr(const SCEVHandle &Op, + const Type *Ty) { + assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && + "This is not a truncating conversion!"); + assert(isSCEVable(Ty) && + "This is not a conversion to a SCEVable type!"); + Ty = getEffectiveSCEVType(Ty); + + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) + return getUnknown( + ConstantExpr::getTrunc(SC->getValue(), Ty)); + + // trunc(trunc(x)) --> trunc(x) + if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) + return getTruncateExpr(ST->getOperand(), Ty); + + // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing + if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) + return getTruncateOrSignExtend(SS->getOperand(), Ty); + + // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing + if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) + return getTruncateOrZeroExtend(SZ->getOperand(), Ty); + + // If the input value is a chrec scev made out of constants, truncate + // all of the constants. 
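+  // (For instance, truncating a hypothetical i32 recurrence {8,+,256} to i8
+  // yields {8,+,0}; each operand is truncated independently and the result is
+  // rebuilt with getAddRecExpr below.)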
+ if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) { + std::vector<SCEVHandle> Operands; + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) + Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty)); + return getAddRecExpr(Operands, AddRec->getLoop()); + } + + SCEVTruncateExpr *&Result = (*SCEVTruncates)[std::make_pair(Op, Ty)]; + if (Result == 0) Result = new SCEVTruncateExpr(Op, Ty); + return Result; +} + +SCEVHandle ScalarEvolution::getZeroExtendExpr(const SCEVHandle &Op, + const Type *Ty) { + assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && + "This is not an extending conversion!"); + assert(isSCEVable(Ty) && + "This is not a conversion to a SCEVable type!"); + Ty = getEffectiveSCEVType(Ty); + + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) { + const Type *IntTy = getEffectiveSCEVType(Ty); + Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy); + if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty); + return getUnknown(C); + } + + // zext(zext(x)) --> zext(x) + if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) + return getZeroExtendExpr(SZ->getOperand(), Ty); + + // If the input value is a chrec scev, and we can prove that the value + // did not overflow the old, smaller, value, we can zero extend all of the + // operands (often constants). This allows analysis of something like + // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; } + if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) + if (AR->isAffine()) { + // Check whether the backedge-taken count is SCEVCouldNotCompute. + // Note that this serves two purposes: It filters out loops that are + // simply not analyzable, and it covers the case where this code is + // being called from within backedge-taken count analysis, such that + // attempting to ask for the backedge-taken count would likely result + // in infinite recursion. In the later case, the analysis code will + // cope with a conservative value, and it will take care to purge + // that value once it has finished. + SCEVHandle MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); + if (!isa<SCEVCouldNotCompute>(MaxBECount)) { + // Manually compute the final value for AR, checking for + // overflow. + SCEVHandle Start = AR->getStart(); + SCEVHandle Step = AR->getStepRecurrence(*this); + + // Check whether the backedge-taken count can be losslessly casted to + // the addrec's type. The count is always unsigned. + SCEVHandle CastedMaxBECount = + getTruncateOrZeroExtend(MaxBECount, Start->getType()); + SCEVHandle RecastedMaxBECount = + getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); + if (MaxBECount == RecastedMaxBECount) { + const Type *WideTy = + IntegerType::get(getTypeSizeInBits(Start->getType()) * 2); + // Check whether Start+Step*MaxBECount has no unsigned overflow. + SCEVHandle ZMul = + getMulExpr(CastedMaxBECount, + getTruncateOrZeroExtend(Step, Start->getType())); + SCEVHandle Add = getAddExpr(Start, ZMul); + SCEVHandle OperandExtendedAdd = + getAddExpr(getZeroExtendExpr(Start, WideTy), + getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), + getZeroExtendExpr(Step, WideTy))); + if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getZeroExtendExpr(Start, Ty), + getZeroExtendExpr(Step, Ty), + AR->getLoop()); + + // Similar to above, only this time treat the step value as signed. + // This covers loops that count down. 
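+          // (Illustration: an induction variable that counts down by 1 each
+          // iteration has a step that is all-ones as an unsigned value;
+          // interpreting that step as signed is what lets the no-overflow
+          // check below succeed for such loops.)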
+ SCEVHandle SMul = + getMulExpr(CastedMaxBECount, + getTruncateOrSignExtend(Step, Start->getType())); + Add = getAddExpr(Start, SMul); + OperandExtendedAdd = + getAddExpr(getZeroExtendExpr(Start, WideTy), + getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), + getSignExtendExpr(Step, WideTy))); + if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) + // Return the expression with the addrec on the outside. + return getAddRecExpr(getZeroExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + AR->getLoop()); + } + } + } + + SCEVZeroExtendExpr *&Result = (*SCEVZeroExtends)[std::make_pair(Op, Ty)]; + if (Result == 0) Result = new SCEVZeroExtendExpr(Op, Ty); + return Result; +} + +SCEVHandle ScalarEvolution::getSignExtendExpr(const SCEVHandle &Op, + const Type *Ty) { + assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && + "This is not an extending conversion!"); + assert(isSCEVable(Ty) && + "This is not a conversion to a SCEVable type!"); + Ty = getEffectiveSCEVType(Ty); + + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) { + const Type *IntTy = getEffectiveSCEVType(Ty); + Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy); + if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty); + return getUnknown(C); + } + + // sext(sext(x)) --> sext(x) + if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) + return getSignExtendExpr(SS->getOperand(), Ty); + + // If the input value is a chrec scev, and we can prove that the value + // did not overflow the old, smaller, value, we can sign extend all of the + // operands (often constants). This allows analysis of something like + // this: for (signed char X = 0; X < 100; ++X) { int Y = X; } + if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) + if (AR->isAffine()) { + // Check whether the backedge-taken count is SCEVCouldNotCompute. + // Note that this serves two purposes: It filters out loops that are + // simply not analyzable, and it covers the case where this code is + // being called from within backedge-taken count analysis, such that + // attempting to ask for the backedge-taken count would likely result + // in infinite recursion. In the later case, the analysis code will + // cope with a conservative value, and it will take care to purge + // that value once it has finished. + SCEVHandle MaxBECount = getMaxBackedgeTakenCount(AR->getLoop()); + if (!isa<SCEVCouldNotCompute>(MaxBECount)) { + // Manually compute the final value for AR, checking for + // overflow. + SCEVHandle Start = AR->getStart(); + SCEVHandle Step = AR->getStepRecurrence(*this); + + // Check whether the backedge-taken count can be losslessly casted to + // the addrec's type. The count is always unsigned. + SCEVHandle CastedMaxBECount = + getTruncateOrZeroExtend(MaxBECount, Start->getType()); + SCEVHandle RecastedMaxBECount = + getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); + if (MaxBECount == RecastedMaxBECount) { + const Type *WideTy = + IntegerType::get(getTypeSizeInBits(Start->getType()) * 2); + // Check whether Start+Step*MaxBECount has no signed overflow. + SCEVHandle SMul = + getMulExpr(CastedMaxBECount, + getTruncateOrSignExtend(Step, Start->getType())); + SCEVHandle Add = getAddExpr(Start, SMul); + SCEVHandle OperandExtendedAdd = + getAddExpr(getSignExtendExpr(Start, WideTy), + getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy), + getSignExtendExpr(Step, WideTy))); + if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) + // Return the expression with the addrec on the outside. 
+ return getAddRecExpr(getSignExtendExpr(Start, Ty), + getSignExtendExpr(Step, Ty), + AR->getLoop()); + } + } + } + + SCEVSignExtendExpr *&Result = (*SCEVSignExtends)[std::make_pair(Op, Ty)]; + if (Result == 0) Result = new SCEVSignExtendExpr(Op, Ty); + return Result; +} + +/// getAddExpr - Get a canonical add expression, or something simpler if +/// possible. +SCEVHandle ScalarEvolution::getAddExpr(std::vector<SCEVHandle> &Ops) { + assert(!Ops.empty() && "Cannot get empty add!"); + if (Ops.size() == 1) return Ops[0]; +#ifndef NDEBUG + for (unsigned i = 1, e = Ops.size(); i != e; ++i) + assert(getEffectiveSCEVType(Ops[i]->getType()) == + getEffectiveSCEVType(Ops[0]->getType()) && + "SCEVAddExpr operand types don't match!"); +#endif + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops, LI); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() + + RHSC->getValue()->getValue()); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast<SCEVConstant>(Ops[0]); + } + + // If we are left with a constant zero being added, strip it off. + if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) { + Ops.erase(Ops.begin()); + --Idx; + } + } + + if (Ops.size() == 1) return Ops[0]; + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, merge them together into an multiply expression. Since we sorted the + // list, these values are required to be adjacent. + const Type *Ty = Ops[0]->getType(); + for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2 + // Found a match, merge the two values into a multiply, and add any + // remaining values to the result. + SCEVHandle Two = getIntegerSCEV(2, Ty); + SCEVHandle Mul = getMulExpr(Ops[i], Two); + if (Ops.size() == 2) + return Mul; + Ops.erase(Ops.begin()+i, Ops.begin()+i+2); + Ops.push_back(Mul); + return getAddExpr(Ops); + } + + // Check for truncates. If all the operands are truncated from the same + // type, see if factoring out the truncate would permit the result to be + // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n) + // if the contents of the resulting outer trunc fold to something simple. + for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) { + const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]); + const Type *DstType = Trunc->getType(); + const Type *SrcType = Trunc->getOperand()->getType(); + std::vector<SCEVHandle> LargeOps; + bool Ok = true; + // Check all the operands to see if they can be represented in the + // source type of the truncate. + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) { + if (T->getOperand()->getType() != SrcType) { + Ok = false; + break; + } + LargeOps.push_back(T->getOperand()); + } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) { + // This could be either sign or zero extension, but sign extension + // is much more likely to be foldable here. 
+ LargeOps.push_back(getSignExtendExpr(C, SrcType)); + } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) { + std::vector<SCEVHandle> LargeMulOps; + for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { + if (const SCEVTruncateExpr *T = + dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) { + if (T->getOperand()->getType() != SrcType) { + Ok = false; + break; + } + LargeMulOps.push_back(T->getOperand()); + } else if (const SCEVConstant *C = + dyn_cast<SCEVConstant>(M->getOperand(j))) { + // This could be either sign or zero extension, but sign extension + // is much more likely to be foldable here. + LargeMulOps.push_back(getSignExtendExpr(C, SrcType)); + } else { + Ok = false; + break; + } + } + if (Ok) + LargeOps.push_back(getMulExpr(LargeMulOps)); + } else { + Ok = false; + break; + } + } + if (Ok) { + // Evaluate the expression in the larger type. + SCEVHandle Fold = getAddExpr(LargeOps); + // If it folds to something simple, use it. Otherwise, don't. + if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold)) + return getTruncateExpr(Fold, DstType); + } + } + + // Skip past any other cast SCEVs. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) + ++Idx; + + // If there are add operands they would be next. + if (Idx < Ops.size()) { + bool DeletedAdd = false; + while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) { + // If we have an add, expand the add operands onto the end of the operands + // list. + Ops.insert(Ops.end(), Add->op_begin(), Add->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedAdd = true; + } + + // If we deleted at least one add, we added operands to the end of the list, + // and they are not necessarily sorted. Recurse to resort and resimplify + // any operands we just aquired. + if (DeletedAdd) + return getAddExpr(Ops); + } + + // Skip over the add expression until we get to a multiply. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) + ++Idx; + + // If we are adding something to a multiply expression, make sure the + // something is not already an operand of the multiply. If so, merge it into + // the multiply. + for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) { + const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]); + for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { + const SCEV *MulOpSCEV = Mul->getOperand(MulOp); + for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) + if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(MulOpSCEV)) { + // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1)) + SCEVHandle InnerMul = Mul->getOperand(MulOp == 0); + if (Mul->getNumOperands() != 2) { + // If the multiply has more than two operands, we must get the + // Y*Z term. + std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end()); + MulOps.erase(MulOps.begin()+MulOp); + InnerMul = getMulExpr(MulOps); + } + SCEVHandle One = getIntegerSCEV(1, Ty); + SCEVHandle AddOne = getAddExpr(InnerMul, One); + SCEVHandle OuterMul = getMulExpr(AddOne, Ops[AddOp]); + if (Ops.size() == 2) return OuterMul; + if (AddOp < Idx) { + Ops.erase(Ops.begin()+AddOp); + Ops.erase(Ops.begin()+Idx-1); + } else { + Ops.erase(Ops.begin()+Idx); + Ops.erase(Ops.begin()+AddOp-1); + } + Ops.push_back(OuterMul); + return getAddExpr(Ops); + } + + // Check this multiply against other multiplies being added together. 
+ for (unsigned OtherMulIdx = Idx+1; + OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]); + ++OtherMulIdx) { + const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]); + // If MulOp occurs in OtherMul, we can fold the two multiplies + // together. + for (unsigned OMulOp = 0, e = OtherMul->getNumOperands(); + OMulOp != e; ++OMulOp) + if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { + // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) + SCEVHandle InnerMul1 = Mul->getOperand(MulOp == 0); + if (Mul->getNumOperands() != 2) { + std::vector<SCEVHandle> MulOps(Mul->op_begin(), Mul->op_end()); + MulOps.erase(MulOps.begin()+MulOp); + InnerMul1 = getMulExpr(MulOps); + } + SCEVHandle InnerMul2 = OtherMul->getOperand(OMulOp == 0); + if (OtherMul->getNumOperands() != 2) { + std::vector<SCEVHandle> MulOps(OtherMul->op_begin(), + OtherMul->op_end()); + MulOps.erase(MulOps.begin()+OMulOp); + InnerMul2 = getMulExpr(MulOps); + } + SCEVHandle InnerMulSum = getAddExpr(InnerMul1,InnerMul2); + SCEVHandle OuterMul = getMulExpr(MulOpSCEV, InnerMulSum); + if (Ops.size() == 2) return OuterMul; + Ops.erase(Ops.begin()+Idx); + Ops.erase(Ops.begin()+OtherMulIdx-1); + Ops.push_back(OuterMul); + return getAddExpr(Ops); + } + } + } + } + + // If there are any add recurrences in the operands list, see if any other + // added values are loop invariant. If so, we can fold them into the + // recurrence. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) + ++Idx; + + // Scan over all recurrences, trying to fold loop invariants into them. + for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { + // Scan all of the other operands to this add and add them to the vector if + // they are loop invariant w.r.t. the recurrence. + std::vector<SCEVHandle> LIOps; + const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + if (Ops[i]->isLoopInvariant(AddRec->getLoop())) { + LIOps.push_back(Ops[i]); + Ops.erase(Ops.begin()+i); + --i; --e; + } + + // If we found some loop invariants, fold them into the recurrence. + if (!LIOps.empty()) { + // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} + LIOps.push_back(AddRec->getStart()); + + std::vector<SCEVHandle> AddRecOps(AddRec->op_begin(), AddRec->op_end()); + AddRecOps[0] = getAddExpr(LIOps); + + SCEVHandle NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop()); + // If all of the other operands were loop invariant, we are done. + if (Ops.size() == 1) return NewRec; + + // Otherwise, add the folded AddRec by the non-liv parts. + for (unsigned i = 0;; ++i) + if (Ops[i] == AddRec) { + Ops[i] = NewRec; + break; + } + return getAddExpr(Ops); + } + + // Okay, if there weren't any loop invariants to be folded, check to see if + // there are multiple AddRec's with the same loop induction variable being + // added together. If so, we can fold them. 
+ for (unsigned OtherIdx = Idx+1; + OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx) + if (OtherIdx != Idx) { + const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]); + if (AddRec->getLoop() == OtherAddRec->getLoop()) { + // Other + {A,+,B} + {C,+,D} --> Other + {A+C,+,B+D} + std::vector<SCEVHandle> NewOps(AddRec->op_begin(), AddRec->op_end()); + for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) { + if (i >= NewOps.size()) { + NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i, + OtherAddRec->op_end()); + break; + } + NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i)); + } + SCEVHandle NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop()); + + if (Ops.size() == 2) return NewAddRec; + + Ops.erase(Ops.begin()+Idx); + Ops.erase(Ops.begin()+OtherIdx-1); + Ops.push_back(NewAddRec); + return getAddExpr(Ops); + } + } + + // Otherwise couldn't fold anything into this recurrence. Move onto the + // next one. + } + + // Okay, it looks like we really DO need an add expr. Check to see if we + // already have one, otherwise create a new one. + std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end()); + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scAddExpr, + SCEVOps)]; + if (Result == 0) Result = new SCEVAddExpr(Ops); + return Result; +} + + +/// getMulExpr - Get a canonical multiply expression, or something simpler if +/// possible. +SCEVHandle ScalarEvolution::getMulExpr(std::vector<SCEVHandle> &Ops) { + assert(!Ops.empty() && "Cannot get empty mul!"); +#ifndef NDEBUG + for (unsigned i = 1, e = Ops.size(); i != e; ++i) + assert(getEffectiveSCEVType(Ops[i]->getType()) == + getEffectiveSCEVType(Ops[0]->getType()) && + "SCEVMulExpr operand types don't match!"); +#endif + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops, LI); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { + + // C1*(C2+V) -> C1*C2 + C1*V + if (Ops.size() == 2) + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) + if (Add->getNumOperands() == 2 && + isa<SCEVConstant>(Add->getOperand(0))) + return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), + getMulExpr(LHSC, Add->getOperand(1))); + + + ++Idx; + while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() * + RHSC->getValue()->getValue()); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast<SCEVConstant>(Ops[0]); + } + + // If we are left with a constant one being multiplied, strip it off. + if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) { + Ops.erase(Ops.begin()); + --Idx; + } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) { + // If we have a multiply of zero, it will always be zero. + return Ops[0]; + } + } + + // Skip over the add expression until we get to a multiply. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) + ++Idx; + + if (Ops.size() == 1) + return Ops[0]; + + // If there are mul operands inline them all into this expression. + if (Idx < Ops.size()) { + bool DeletedMul = false; + while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { + // If we have an mul, expand the mul operands onto the end of the operands + // list. 
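+      // (For example, (A*B)*C*D is flattened to A*B*C*D here and then re-sorted
+      // and re-simplified by the recursive call below.)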
+ Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedMul = true; + } + + // If we deleted at least one mul, we added operands to the end of the list, + // and they are not necessarily sorted. Recurse to resort and resimplify + // any operands we just aquired. + if (DeletedMul) + return getMulExpr(Ops); + } + + // If there are any add recurrences in the operands list, see if any other + // added values are loop invariant. If so, we can fold them into the + // recurrence. + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) + ++Idx; + + // Scan over all recurrences, trying to fold loop invariants into them. + for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { + // Scan all of the other operands to this mul and add them to the vector if + // they are loop invariant w.r.t. the recurrence. + std::vector<SCEVHandle> LIOps; + const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + if (Ops[i]->isLoopInvariant(AddRec->getLoop())) { + LIOps.push_back(Ops[i]); + Ops.erase(Ops.begin()+i); + --i; --e; + } + + // If we found some loop invariants, fold them into the recurrence. + if (!LIOps.empty()) { + // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step} + std::vector<SCEVHandle> NewOps; + NewOps.reserve(AddRec->getNumOperands()); + if (LIOps.size() == 1) { + const SCEV *Scale = LIOps[0]; + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) + NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i))); + } else { + for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { + std::vector<SCEVHandle> MulOps(LIOps); + MulOps.push_back(AddRec->getOperand(i)); + NewOps.push_back(getMulExpr(MulOps)); + } + } + + SCEVHandle NewRec = getAddRecExpr(NewOps, AddRec->getLoop()); + + // If all of the other operands were loop invariant, we are done. + if (Ops.size() == 1) return NewRec; + + // Otherwise, multiply the folded AddRec by the non-liv parts. + for (unsigned i = 0;; ++i) + if (Ops[i] == AddRec) { + Ops[i] = NewRec; + break; + } + return getMulExpr(Ops); + } + + // Okay, if there weren't any loop invariants to be folded, check to see if + // there are multiple AddRec's with the same loop induction variable being + // multiplied together. If so, we can fold them. + for (unsigned OtherIdx = Idx+1; + OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx) + if (OtherIdx != Idx) { + const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]); + if (AddRec->getLoop() == OtherAddRec->getLoop()) { + // F * G --> {A,+,B} * {C,+,D} --> {A*C,+,F*D + G*B + B*D} + const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec; + SCEVHandle NewStart = getMulExpr(F->getStart(), + G->getStart()); + SCEVHandle B = F->getStepRecurrence(*this); + SCEVHandle D = G->getStepRecurrence(*this); + SCEVHandle NewStep = getAddExpr(getMulExpr(F, D), + getMulExpr(G, B), + getMulExpr(B, D)); + SCEVHandle NewAddRec = getAddRecExpr(NewStart, NewStep, + F->getLoop()); + if (Ops.size() == 2) return NewAddRec; + + Ops.erase(Ops.begin()+Idx); + Ops.erase(Ops.begin()+OtherIdx-1); + Ops.push_back(NewAddRec); + return getMulExpr(Ops); + } + } + + // Otherwise couldn't fold anything into this recurrence. Move onto the + // next one. + } + + // Okay, it looks like we really DO need an mul expr. Check to see if we + // already have one, otherwise create a new one. 
+  std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end());
+  SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scMulExpr,
+                                                                 SCEVOps)];
+  if (Result == 0)
+    Result = new SCEVMulExpr(Ops);
+  return Result;
+}
+
+/// getUDivExpr - Get a canonical unsigned division expression, or something
+/// simpler if possible.
+SCEVHandle ScalarEvolution::getUDivExpr(const SCEVHandle &LHS,
+                                        const SCEVHandle &RHS) {
+  assert(getEffectiveSCEVType(LHS->getType()) ==
+         getEffectiveSCEVType(RHS->getType()) &&
+         "SCEVUDivExpr operand types don't match!");
+
+  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
+    if (RHSC->getValue()->equalsInt(1))
+      return LHS;                               // X udiv 1 --> X
+    if (RHSC->isZero())
+      return getIntegerSCEV(0, LHS->getType()); // value is undefined
+
+    // Determine whether the division can be folded into the operands of
+    // its operands.
+    // TODO: Generalize this to non-constants by using known-bits information.
+    const Type *Ty = LHS->getType();
+    unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
+    unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ;
+    // For non-power-of-two values, effectively round the value up to the
+    // nearest power of two.
+    if (!RHSC->getValue()->getValue().isPowerOf2())
+      ++MaxShiftAmt;
+    const IntegerType *ExtTy =
+      IntegerType::get(getTypeSizeInBits(Ty) + MaxShiftAmt);
+    // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
+    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
+      if (const SCEVConstant *Step =
+            dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
+        if (!Step->getValue()->getValue()
+              .urem(RHSC->getValue()->getValue()) &&
+            getZeroExtendExpr(AR, ExtTy) ==
+            getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
+                          getZeroExtendExpr(Step, ExtTy),
+                          AR->getLoop())) {
+          std::vector<SCEVHandle> Operands;
+          for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
+            Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
+          return getAddRecExpr(Operands, AR->getLoop());
+        }
+    // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
+    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
+      std::vector<SCEVHandle> Operands;
+      for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
+        Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
+      if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
+        // Find an operand that's safely divisible.
+        for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
+          SCEVHandle Op = M->getOperand(i);
+          SCEVHandle Div = getUDivExpr(Op, RHSC);
+          if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
+            Operands = M->getOperands();
+            Operands[i] = Div;
+            return getMulExpr(Operands);
+          }
+        }
+    }
+    // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
+    if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(LHS)) {
+      std::vector<SCEVHandle> Operands;
+      for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
+        Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
+      if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
+        Operands.clear();
+        for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
+          SCEVHandle Op = getUDivExpr(A->getOperand(i), RHS);
+          if (isa<SCEVUDivExpr>(Op) || getMulExpr(Op, RHS) != A->getOperand(i))
+            break;
+          Operands.push_back(Op);
+        }
+        if (Operands.size() == A->getNumOperands())
+          return getAddExpr(Operands);
+      }
+    }
+
+    // Fold if both operands are constant.
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { + Constant *LHSCV = LHSC->getValue(); + Constant *RHSCV = RHSC->getValue(); + return getUnknown(ConstantExpr::getUDiv(LHSCV, RHSCV)); + } + } + + SCEVUDivExpr *&Result = (*SCEVUDivs)[std::make_pair(LHS, RHS)]; + if (Result == 0) Result = new SCEVUDivExpr(LHS, RHS); + return Result; +} + + +/// getAddRecExpr - Get an add recurrence expression for the specified loop. +/// Simplify the expression as much as possible. +SCEVHandle ScalarEvolution::getAddRecExpr(const SCEVHandle &Start, + const SCEVHandle &Step, const Loop *L) { + std::vector<SCEVHandle> Operands; + Operands.push_back(Start); + if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step)) + if (StepChrec->getLoop() == L) { + Operands.insert(Operands.end(), StepChrec->op_begin(), + StepChrec->op_end()); + return getAddRecExpr(Operands, L); + } + + Operands.push_back(Step); + return getAddRecExpr(Operands, L); +} + +/// getAddRecExpr - Get an add recurrence expression for the specified loop. +/// Simplify the expression as much as possible. +SCEVHandle ScalarEvolution::getAddRecExpr(std::vector<SCEVHandle> &Operands, + const Loop *L) { + if (Operands.size() == 1) return Operands[0]; +#ifndef NDEBUG + for (unsigned i = 1, e = Operands.size(); i != e; ++i) + assert(getEffectiveSCEVType(Operands[i]->getType()) == + getEffectiveSCEVType(Operands[0]->getType()) && + "SCEVAddRecExpr operand types don't match!"); +#endif + + if (Operands.back()->isZero()) { + Operands.pop_back(); + return getAddRecExpr(Operands, L); // {X,+,0} --> X + } + + // Canonicalize nested AddRecs in by nesting them in order of loop depth. + if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) { + const Loop* NestedLoop = NestedAR->getLoop(); + if (L->getLoopDepth() < NestedLoop->getLoopDepth()) { + std::vector<SCEVHandle> NestedOperands(NestedAR->op_begin(), + NestedAR->op_end()); + SCEVHandle NestedARHandle(NestedAR); + Operands[0] = NestedAR->getStart(); + NestedOperands[0] = getAddRecExpr(Operands, L); + return getAddRecExpr(NestedOperands, NestedLoop); + } + } + + std::vector<const SCEV*> SCEVOps(Operands.begin(), Operands.end()); + SCEVAddRecExpr *&Result = (*SCEVAddRecExprs)[std::make_pair(L, SCEVOps)]; + if (Result == 0) Result = new SCEVAddRecExpr(Operands, L); + return Result; +} + +SCEVHandle ScalarEvolution::getSMaxExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + std::vector<SCEVHandle> Ops; + Ops.push_back(LHS); + Ops.push_back(RHS); + return getSMaxExpr(Ops); +} + +SCEVHandle ScalarEvolution::getSMaxExpr(std::vector<SCEVHandle> Ops) { + assert(!Ops.empty() && "Cannot get empty smax!"); + if (Ops.size() == 1) return Ops[0]; +#ifndef NDEBUG + for (unsigned i = 1, e = Ops.size(); i != e; ++i) + assert(getEffectiveSCEVType(Ops[i]->getType()) == + getEffectiveSCEVType(Ops[0]->getType()) && + "SCEVSMaxExpr operand types don't match!"); +#endif + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops, LI); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { + // We found two constants, fold them together! 
+ ConstantInt *Fold = ConstantInt::get( + APIntOps::smax(LHSC->getValue()->getValue(), + RHSC->getValue()->getValue())); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast<SCEVConstant>(Ops[0]); + } + + // If we are left with a constant -inf, strip it off. + if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) { + Ops.erase(Ops.begin()); + --Idx; + } + } + + if (Ops.size() == 1) return Ops[0]; + + // Find the first SMax + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr) + ++Idx; + + // Check to see if one of the operands is an SMax. If so, expand its operands + // onto our operand list, and recurse to simplify. + if (Idx < Ops.size()) { + bool DeletedSMax = false; + while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) { + Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedSMax = true; + } + + if (DeletedSMax) + return getSMaxExpr(Ops); + } + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, delete one. Since we sorted the list, these values are required to + // be adjacent. + for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + if (Ops[i] == Ops[i+1]) { // X smax Y smax Y --> X smax Y + Ops.erase(Ops.begin()+i, Ops.begin()+i+1); + --i; --e; + } + + if (Ops.size() == 1) return Ops[0]; + + assert(!Ops.empty() && "Reduced smax down to nothing!"); + + // Okay, it looks like we really DO need an smax expr. Check to see if we + // already have one, otherwise create a new one. + std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end()); + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scSMaxExpr, + SCEVOps)]; + if (Result == 0) Result = new SCEVSMaxExpr(Ops); + return Result; +} + +SCEVHandle ScalarEvolution::getUMaxExpr(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + std::vector<SCEVHandle> Ops; + Ops.push_back(LHS); + Ops.push_back(RHS); + return getUMaxExpr(Ops); +} + +SCEVHandle ScalarEvolution::getUMaxExpr(std::vector<SCEVHandle> Ops) { + assert(!Ops.empty() && "Cannot get empty umax!"); + if (Ops.size() == 1) return Ops[0]; +#ifndef NDEBUG + for (unsigned i = 1, e = Ops.size(); i != e; ++i) + assert(getEffectiveSCEVType(Ops[i]->getType()) == + getEffectiveSCEVType(Ops[0]->getType()) && + "SCEVUMaxExpr operand types don't match!"); +#endif + + // Sort by complexity, this groups all similar expression types together. + GroupByComplexity(Ops, LI); + + // If there are any constants, fold them together. + unsigned Idx = 0; + if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { + // We found two constants, fold them together! + ConstantInt *Fold = ConstantInt::get( + APIntOps::umax(LHSC->getValue()->getValue(), + RHSC->getValue()->getValue())); + Ops[0] = getConstant(Fold); + Ops.erase(Ops.begin()+1); // Erase the folded element + if (Ops.size() == 1) return Ops[0]; + LHSC = cast<SCEVConstant>(Ops[0]); + } + + // If we are left with a constant zero, strip it off. + if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) { + Ops.erase(Ops.begin()); + --Idx; + } + } + + if (Ops.size() == 1) return Ops[0]; + + // Find the first UMax + while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr) + ++Idx; + + // Check to see if one of the operands is a UMax. If so, expand its operands + // onto our operand list, and recurse to simplify. 
+ if (Idx < Ops.size()) { + bool DeletedUMax = false; + while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) { + Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end()); + Ops.erase(Ops.begin()+Idx); + DeletedUMax = true; + } + + if (DeletedUMax) + return getUMaxExpr(Ops); + } + + // Okay, check to see if the same value occurs in the operand list twice. If + // so, delete one. Since we sorted the list, these values are required to + // be adjacent. + for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) + if (Ops[i] == Ops[i+1]) { // X umax Y umax Y --> X umax Y + Ops.erase(Ops.begin()+i, Ops.begin()+i+1); + --i; --e; + } + + if (Ops.size() == 1) return Ops[0]; + + assert(!Ops.empty() && "Reduced umax down to nothing!"); + + // Okay, it looks like we really DO need a umax expr. Check to see if we + // already have one, otherwise create a new one. + std::vector<const SCEV*> SCEVOps(Ops.begin(), Ops.end()); + SCEVCommutativeExpr *&Result = (*SCEVCommExprs)[std::make_pair(scUMaxExpr, + SCEVOps)]; + if (Result == 0) Result = new SCEVUMaxExpr(Ops); + return Result; +} + +SCEVHandle ScalarEvolution::getUnknown(Value *V) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) + return getConstant(CI); + if (isa<ConstantPointerNull>(V)) + return getIntegerSCEV(0, V->getType()); + SCEVUnknown *&Result = (*SCEVUnknowns)[V]; + if (Result == 0) Result = new SCEVUnknown(V); + return Result; +} + +//===----------------------------------------------------------------------===// +// Basic SCEV Analysis and PHI Idiom Recognition Code +// + +/// isSCEVable - Test if values of the given type are analyzable within +/// the SCEV framework. This primarily includes integer types, and it +/// can optionally include pointer types if the ScalarEvolution class +/// has access to target-specific information. +bool ScalarEvolution::isSCEVable(const Type *Ty) const { + // Integers are always SCEVable. + if (Ty->isInteger()) + return true; + + // Pointers are SCEVable if TargetData information is available + // to provide pointer size information. + if (isa<PointerType>(Ty)) + return TD != NULL; + + // Otherwise it's not SCEVable. + return false; +} + +/// getTypeSizeInBits - Return the size in bits of the specified type, +/// for which isSCEVable must return true. +uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const { + assert(isSCEVable(Ty) && "Type is not SCEVable!"); + + // If we have a TargetData, use it! + if (TD) + return TD->getTypeSizeInBits(Ty); + + // Otherwise, we support only integer types. + assert(Ty->isInteger() && "isSCEVable permitted a non-SCEVable type!"); + return Ty->getPrimitiveSizeInBits(); +} + +/// getEffectiveSCEVType - Return a type with the same bitwidth as +/// the given type and which represents how SCEV will treat the given +/// type, for which isSCEVable must return true. For pointer types, +/// this is the pointer-sized integer type. +const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const { + assert(isSCEVable(Ty) && "Type is not SCEVable!"); + + if (Ty->isInteger()) + return Ty; + + assert(isa<PointerType>(Ty) && "Unexpected non-pointer non-integer type!"); + return TD->getIntPtrType(); +} + +SCEVHandle ScalarEvolution::getCouldNotCompute() { + return UnknownValue; +} + +/// hasSCEV - Return true if the SCEV for this value has already been +/// computed. +bool ScalarEvolution::hasSCEV(Value *V) const { + return Scalars.count(V); +} + +/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the +/// expression and create a new one. 
+SCEVHandle ScalarEvolution::getSCEV(Value *V) { + assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); + + std::map<SCEVCallbackVH, SCEVHandle>::iterator I = Scalars.find(V); + if (I != Scalars.end()) return I->second; + SCEVHandle S = createSCEV(V); + Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S)); + return S; +} + +/// getIntegerSCEV - Given an integer or FP type, create a constant for the +/// specified signed integer value and return a SCEV for the constant. +SCEVHandle ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) { + Ty = getEffectiveSCEVType(Ty); + Constant *C; + if (Val == 0) + C = Constant::getNullValue(Ty); + else if (Ty->isFloatingPoint()) + C = ConstantFP::get(APFloat(Ty==Type::FloatTy ? APFloat::IEEEsingle : + APFloat::IEEEdouble, Val)); + else + C = ConstantInt::get(Ty, Val); + return getUnknown(C); +} + +/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V +/// +SCEVHandle ScalarEvolution::getNegativeSCEV(const SCEVHandle &V) { + if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) + return getUnknown(ConstantExpr::getNeg(VC->getValue())); + + const Type *Ty = V->getType(); + Ty = getEffectiveSCEVType(Ty); + return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(Ty))); +} + +/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V +SCEVHandle ScalarEvolution::getNotSCEV(const SCEVHandle &V) { + if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) + return getUnknown(ConstantExpr::getNot(VC->getValue())); + + const Type *Ty = V->getType(); + Ty = getEffectiveSCEVType(Ty); + SCEVHandle AllOnes = getConstant(ConstantInt::getAllOnesValue(Ty)); + return getMinusSCEV(AllOnes, V); +} + +/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS. +/// +SCEVHandle ScalarEvolution::getMinusSCEV(const SCEVHandle &LHS, + const SCEVHandle &RHS) { + // X - Y --> X + -Y + return getAddExpr(LHS, getNegativeSCEV(RHS)); +} + +/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the +/// input value to the specified type. If the type must be extended, it is zero +/// extended. +SCEVHandle +ScalarEvolution::getTruncateOrZeroExtend(const SCEVHandle &V, + const Type *Ty) { + const Type *SrcTy = V->getType(); + assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) && + (Ty->isInteger() || (TD && isa<PointerType>(Ty))) && + "Cannot truncate or zero extend with non-integer arguments!"); + if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) + return V; // No conversion + if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty)) + return getTruncateExpr(V, Ty); + return getZeroExtendExpr(V, Ty); +} + +/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the +/// input value to the specified type. If the type must be extended, it is sign +/// extended. +SCEVHandle +ScalarEvolution::getTruncateOrSignExtend(const SCEVHandle &V, + const Type *Ty) { + const Type *SrcTy = V->getType(); + assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) && + (Ty->isInteger() || (TD && isa<PointerType>(Ty))) && + "Cannot truncate or zero extend with non-integer arguments!"); + if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) + return V; // No conversion + if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty)) + return getTruncateExpr(V, Ty); + return getSignExtendExpr(V, Ty); +} + +/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the +/// input value to the specified type. If the type must be extended, it is zero +/// extended. The conversion must not be narrowing. 
+SCEVHandle +ScalarEvolution::getNoopOrZeroExtend(const SCEVHandle &V, const Type *Ty) { + const Type *SrcTy = V->getType(); + assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) && + (Ty->isInteger() || (TD && isa<PointerType>(Ty))) && + "Cannot noop or zero extend with non-integer arguments!"); + assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && + "getNoopOrZeroExtend cannot truncate!"); + if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) + return V; // No conversion + return getZeroExtendExpr(V, Ty); +} + +/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the +/// input value to the specified type. If the type must be extended, it is sign +/// extended. The conversion must not be narrowing. +SCEVHandle +ScalarEvolution::getNoopOrSignExtend(const SCEVHandle &V, const Type *Ty) { + const Type *SrcTy = V->getType(); + assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) && + (Ty->isInteger() || (TD && isa<PointerType>(Ty))) && + "Cannot noop or sign extend with non-integer arguments!"); + assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && + "getNoopOrSignExtend cannot truncate!"); + if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) + return V; // No conversion + return getSignExtendExpr(V, Ty); +} + +/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the +/// input value to the specified type. The conversion must not be widening. +SCEVHandle +ScalarEvolution::getTruncateOrNoop(const SCEVHandle &V, const Type *Ty) { + const Type *SrcTy = V->getType(); + assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) && + (Ty->isInteger() || (TD && isa<PointerType>(Ty))) && + "Cannot truncate or noop with non-integer arguments!"); + assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) && + "getTruncateOrNoop cannot extend!"); + if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) + return V; // No conversion + return getTruncateExpr(V, Ty); +} + +/// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for +/// the specified instruction and replaces any references to the symbolic value +/// SymName with the specified value. This is used during PHI resolution. +void ScalarEvolution:: +ReplaceSymbolicValueWithConcrete(Instruction *I, const SCEVHandle &SymName, + const SCEVHandle &NewVal) { + std::map<SCEVCallbackVH, SCEVHandle>::iterator SI = + Scalars.find(SCEVCallbackVH(I, this)); + if (SI == Scalars.end()) return; + + SCEVHandle NV = + SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, *this); + if (NV == SI->second) return; // No change. + + SI->second = NV; // Update the scalars map! + + // Any instruction values that use this instruction might also need to be + // updated! + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) + ReplaceSymbolicValueWithConcrete(cast<Instruction>(*UI), SymName, NewVal); +} + +/// createNodeForPHI - PHI nodes have two cases. Either the PHI node exists in +/// a loop header, making it a potential recurrence, or it doesn't. +/// +SCEVHandle ScalarEvolution::createNodeForPHI(PHINode *PN) { + if (PN->getNumIncomingValues() == 2) // The loops have been canonicalized. + if (const Loop *L = LI->getLoopFor(PN->getParent())) + if (L->getHeader() == PN->getParent()) { + // If it lives in the loop header, it has two incoming values, one + // from outside the loop, and one from inside. 
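+        // (contains() yields 0 or 1, so IncomingEdge indexes the incoming value
+        // from outside the loop and BackEdge the value coming from the latch.)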
+ unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0)); + unsigned BackEdge = IncomingEdge^1; + + // While we are analyzing this PHI node, handle its value symbolically. + SCEVHandle SymbolicName = getUnknown(PN); + assert(Scalars.find(PN) == Scalars.end() && + "PHI node already processed?"); + Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName)); + + // Using this symbolic name for the PHI, analyze the value coming around + // the back-edge. + SCEVHandle BEValue = getSCEV(PN->getIncomingValue(BackEdge)); + + // NOTE: If BEValue is loop invariant, we know that the PHI node just + // has a special value for the first iteration of the loop. + + // If the value coming around the backedge is an add with the symbolic + // value we just inserted, then we found a simple induction variable! + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) { + // If there is a single occurrence of the symbolic value, replace it + // with a recurrence. + unsigned FoundIndex = Add->getNumOperands(); + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if (Add->getOperand(i) == SymbolicName) + if (FoundIndex == e) { + FoundIndex = i; + break; + } + + if (FoundIndex != Add->getNumOperands()) { + // Create an add with everything but the specified operand. + std::vector<SCEVHandle> Ops; + for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) + if (i != FoundIndex) + Ops.push_back(Add->getOperand(i)); + SCEVHandle Accum = getAddExpr(Ops); + + // This is not a valid addrec if the step amount is varying each + // loop iteration, but is not itself an addrec in this loop. + if (Accum->isLoopInvariant(L) || + (isa<SCEVAddRecExpr>(Accum) && + cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { + SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge)); + SCEVHandle PHISCEV = getAddRecExpr(StartVal, Accum, L); + + // Okay, for the entire analysis of this edge we assumed the PHI + // to be symbolic. We now need to go back and update all of the + // entries for the scalars that use the PHI (except for the PHI + // itself) to use the new analyzed value instead of the "symbolic" + // value. + ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV); + return PHISCEV; + } + } + } else if (const SCEVAddRecExpr *AddRec = + dyn_cast<SCEVAddRecExpr>(BEValue)) { + // Otherwise, this could be a loop like this: + // i = 0; for (j = 1; ..; ++j) { .... i = j; } + // In this case, j = {1,+,1} and BEValue is j. + // Because the other in-value of i (0) fits the evolution of BEValue + // i really is an addrec evolution. + if (AddRec->getLoop() == L && AddRec->isAffine()) { + SCEVHandle StartVal = getSCEV(PN->getIncomingValue(IncomingEdge)); + + // If StartVal = j.start - j.stride, we can use StartVal as the + // initial step of the addrec evolution. + if (StartVal == getMinusSCEV(AddRec->getOperand(0), + AddRec->getOperand(1))) { + SCEVHandle PHISCEV = + getAddRecExpr(StartVal, AddRec->getOperand(1), L); + + // Okay, for the entire analysis of this edge we assumed the PHI + // to be symbolic. We now need to go back and update all of the + // entries for the scalars that use the PHI (except for the PHI + // itself) to use the new analyzed value instead of the "symbolic" + // value. + ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV); + return PHISCEV; + } + } + } + + return SymbolicName; + } + + // If it's not a loop phi, we can't handle it yet. + return getUnknown(PN); +} + +/// createNodeForGEP - Expand GEP instructions into add and multiply +/// operations. 
This allows them to be analyzed by regular SCEV code. +/// +SCEVHandle ScalarEvolution::createNodeForGEP(User *GEP) { + + const Type *IntPtrTy = TD->getIntPtrType(); + Value *Base = GEP->getOperand(0); + // Don't attempt to analyze GEPs over unsized objects. + if (!cast<PointerType>(Base->getType())->getElementType()->isSized()) + return getUnknown(GEP); + SCEVHandle TotalOffset = getIntegerSCEV(0, IntPtrTy); + gep_type_iterator GTI = gep_type_begin(GEP); + for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()), + E = GEP->op_end(); + I != E; ++I) { + Value *Index = *I; + // Compute the (potentially symbolic) offset in bytes for this index. + if (const StructType *STy = dyn_cast<StructType>(*GTI++)) { + // For a struct, add the member offset. + const StructLayout &SL = *TD->getStructLayout(STy); + unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue(); + uint64_t Offset = SL.getElementOffset(FieldNo); + TotalOffset = getAddExpr(TotalOffset, + getIntegerSCEV(Offset, IntPtrTy)); + } else { + // For an array, add the element offset, explicitly scaled. + SCEVHandle LocalOffset = getSCEV(Index); + if (!isa<PointerType>(LocalOffset->getType())) + // Getelementptr indicies are signed. + LocalOffset = getTruncateOrSignExtend(LocalOffset, + IntPtrTy); + LocalOffset = + getMulExpr(LocalOffset, + getIntegerSCEV(TD->getTypeAllocSize(*GTI), + IntPtrTy)); + TotalOffset = getAddExpr(TotalOffset, LocalOffset); + } + } + return getAddExpr(getSCEV(Base), TotalOffset); +} + +/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is +/// guaranteed to end in (at every loop iteration). It is, at the same time, +/// the minimum number of times S is divisible by 2. For example, given {4,+,8} +/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. +static uint32_t GetMinTrailingZeros(SCEVHandle S, const ScalarEvolution &SE) { + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) + return C->getValue()->getValue().countTrailingZeros(); + + if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S)) + return std::min(GetMinTrailingZeros(T->getOperand(), SE), + (uint32_t)SE.getTypeSizeInBits(T->getType())); + + if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE); + return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ? + SE.getTypeSizeInBits(E->getType()) : OpRes; + } + + if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) { + uint32_t OpRes = GetMinTrailingZeros(E->getOperand(), SE); + return OpRes == SE.getTypeSizeInBits(E->getOperand()->getType()) ? + SE.getTypeSizeInBits(E->getType()) : OpRes; + } + + if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE)); + return MinOpRes; + } + + if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { + // The result is the sum of all operands results. 
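+    // (Trailing zero counts add under multiplication, e.g. 4 * 6 == 24 has
+    // three trailing zero bits; the sum is capped at the bit width below.)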
+ uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0), SE); + uint32_t BitWidth = SE.getTypeSizeInBits(M->getType()); + for (unsigned i = 1, e = M->getNumOperands(); + SumOpRes != BitWidth && i != e; ++i) + SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i), SE), + BitWidth); + return SumOpRes; + } + + if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0), SE); + for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i), SE)); + return MinOpRes; + } + + if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE)); + return MinOpRes; + } + + if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) { + // The result is the min of all operands results. + uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0), SE); + for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) + MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i), SE)); + return MinOpRes; + } + + // SCEVUDivExpr, SCEVUnknown + return 0; +} + +/// createSCEV - We know that there is no SCEV for the specified value. +/// Analyze the expression. +/// +SCEVHandle ScalarEvolution::createSCEV(Value *V) { + if (!isSCEVable(V->getType())) + return getUnknown(V); + + unsigned Opcode = Instruction::UserOp1; + if (Instruction *I = dyn_cast<Instruction>(V)) + Opcode = I->getOpcode(); + else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) + Opcode = CE->getOpcode(); + else + return getUnknown(V); + + User *U = cast<User>(V); + switch (Opcode) { + case Instruction::Add: + return getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Mul: + return getMulExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::UDiv: + return getUDivExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::Sub: + return getMinusSCEV(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + case Instruction::And: + // For an expression like x&255 that merely masks off the high bits, + // use zext(trunc(x)) as the SCEV expression. + if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { + if (CI->isNullValue()) + return getSCEV(U->getOperand(1)); + if (CI->isAllOnesValue()) + return getSCEV(U->getOperand(0)); + const APInt &A = CI->getValue(); + unsigned Ones = A.countTrailingOnes(); + if (APIntOps::isMask(Ones, A)) + return + getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)), + IntegerType::get(Ones)), + U->getType()); + } + break; + case Instruction::Or: + // If the RHS of the Or is a constant, we may have something like: + // X*4+1 which got turned into X*4|1. Handle this as an Add so loop + // optimizations will transparently handle this case. + // + // In order for this transformation to be safe, the LHS must be of the + // form X*(2^n) and the Or constant must be less than 2^n. 
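+    // (For example, (X*4) | 1 is equivalent to (X*4) + 1, because the low two
+    // bits of X*4 are known to be zero.)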
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { + SCEVHandle LHS = getSCEV(U->getOperand(0)); + const APInt &CIVal = CI->getValue(); + if (GetMinTrailingZeros(LHS, *this) >= + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) + return getAddExpr(LHS, getSCEV(U->getOperand(1))); + } + break; + case Instruction::Xor: + if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { + // If the RHS of the xor is a signbit, then this is just an add. + // Instcombine turns add of signbit into xor as a strength reduction step. + if (CI->getValue().isSignBit()) + return getAddExpr(getSCEV(U->getOperand(0)), + getSCEV(U->getOperand(1))); + + // If the RHS of xor is -1, then this is a not operation. + if (CI->isAllOnesValue()) + return getNotSCEV(getSCEV(U->getOperand(0))); + + // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. + // This is a variant of the check for xor with -1, and it handles + // the case where instcombine has trimmed non-demanded bits out + // of an xor with -1. + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0))) + if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1))) + if (BO->getOpcode() == Instruction::And && + LCI->getValue() == CI->getValue()) + if (const SCEVZeroExtendExpr *Z = + dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) + return getZeroExtendExpr(getNotSCEV(Z->getOperand()), + U->getType()); + } + break; + + case Instruction::Shl: + // Turn shift left of a constant amount into a multiply. + if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { + uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth(); + Constant *X = ConstantInt::get( + APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); + return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + } + break; + + case Instruction::LShr: + // Turn logical shift right of a constant into a unsigned divide. + if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { + uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth(); + Constant *X = ConstantInt::get( + APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth))); + return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + } + break; + + case Instruction::AShr: + // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. + if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) + if (Instruction *L = dyn_cast<Instruction>(U->getOperand(0))) + if (L->getOpcode() == Instruction::Shl && + L->getOperand(1) == U->getOperand(1)) { + unsigned BitWidth = getTypeSizeInBits(U->getType()); + uint64_t Amt = BitWidth - CI->getZExtValue(); + if (Amt == BitWidth) + return getSCEV(L->getOperand(0)); // shift by zero --> noop + if (Amt > BitWidth) + return getIntegerSCEV(0, U->getType()); // value is undefined + return + getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)), + IntegerType::get(Amt)), + U->getType()); + } + break; + + case Instruction::Trunc: + return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); + + case Instruction::ZExt: + return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); + + case Instruction::SExt: + return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); + + case Instruction::BitCast: + // BitCasts are no-op casts so we just eliminate the cast. + if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) + return getSCEV(U->getOperand(0)); + break; + + case Instruction::IntToPtr: + if (!TD) break; // Without TD we can't analyze pointers. 
+ return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)), + TD->getIntPtrType()); + + case Instruction::PtrToInt: + if (!TD) break; // Without TD we can't analyze pointers. + return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)), + U->getType()); + + case Instruction::GetElementPtr: + if (!TD) break; // Without TD we can't analyze pointers. + return createNodeForGEP(U); + + case Instruction::PHI: + return createNodeForPHI(cast<PHINode>(U)); + + case Instruction::Select: + // This could be a smax or umax that was lowered earlier. + // Try to recover it. + if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) { + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + switch (ICI->getPredicate()) { + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return getSMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // ~smax(~x, ~y) == smin(x, y). + return getNotSCEV(getSMaxExpr( + getNotSCEV(getSCEV(LHS)), + getNotSCEV(getSCEV(RHS)))); + break; + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + std::swap(LHS, RHS); + // fall through + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + if (LHS == U->getOperand(1) && RHS == U->getOperand(2)) + return getUMaxExpr(getSCEV(LHS), getSCEV(RHS)); + else if (LHS == U->getOperand(2) && RHS == U->getOperand(1)) + // ~umax(~x, ~y) == umin(x, y) + return getNotSCEV(getUMaxExpr(getNotSCEV(getSCEV(LHS)), + getNotSCEV(getSCEV(RHS)))); + break; + default: + break; + } + } + + default: // We cannot analyze this expression. + break; + } + + return getUnknown(V); +} + + + +//===----------------------------------------------------------------------===// +// Iteration Count Computation Code +// + +/// getBackedgeTakenCount - If the specified loop has a predictable +/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute +/// object. The backedge-taken count is the number of times the loop header +/// will be branched to from within the loop. This is one less than the +/// trip count of the loop, since it doesn't count the first iteration, +/// when the header is branched to from outside the loop. +/// +/// Note that it is not valid to call this method on a loop without a +/// loop-invariant backedge-taken count (see +/// hasLoopInvariantBackedgeTakenCount). +/// +SCEVHandle ScalarEvolution::getBackedgeTakenCount(const Loop *L) { + return getBackedgeTakenInfo(L).Exact; +} + +/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except +/// return the least SCEV value that is known never to be less than the +/// actual backedge taken count. +SCEVHandle ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { + return getBackedgeTakenInfo(L).Max; +} + +const ScalarEvolution::BackedgeTakenInfo & +ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { + // Initially insert a CouldNotCompute for this loop. If the insertion + // succeeds, procede to actually compute a backedge-taken count and + // update the value. The temporary CouldNotCompute value tells SCEV + // code elsewhere that it shouldn't attempt to request a new + // backedge-taken count, which could result in infinite recursion. 
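+  // (If a count is already cached for L, insert() leaves the existing entry in
+  // place and returns false, and that entry is returned unchanged below.)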
+  std::pair<std::map<const Loop*, BackedgeTakenInfo>::iterator, bool> Pair =
+    BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
+  if (Pair.second) {
+    BackedgeTakenInfo ItCount = ComputeBackedgeTakenCount(L);
+    if (ItCount.Exact != UnknownValue) {
+      assert(ItCount.Exact->isLoopInvariant(L) &&
+             ItCount.Max->isLoopInvariant(L) &&
+             "Computed trip count isn't loop invariant for loop!");
+      ++NumTripCountsComputed;
+
+      // Update the value in the map.
+      Pair.first->second = ItCount;
+    } else if (isa<PHINode>(L->getHeader()->begin())) {
+      // Only count loops that have phi nodes as not being computable.
+      ++NumTripCountsNotComputed;
+    }
+
+    // Now that we know more about the trip count for this loop, forget any
+    // existing SCEV values for PHI nodes in this loop since they are only
+    // conservative estimates made without the benefit
+    // of trip count information.
+    if (ItCount.hasAnyInfo())
+      forgetLoopPHIs(L);
+  }
+  return Pair.first->second;
+}
+
+/// forgetLoopBackedgeTakenCount - This method should be called by the
+/// client when it has changed a loop in a way that may affect
+/// ScalarEvolution's ability to compute a trip count, or if the loop
+/// is deleted.
+void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) {
+  BackedgeTakenCounts.erase(L);
+  forgetLoopPHIs(L);
+}
+
+/// forgetLoopPHIs - Delete the memoized SCEVs associated with the
+/// PHI nodes in the given loop. This is used when the trip count of
+/// the loop may have changed.
+void ScalarEvolution::forgetLoopPHIs(const Loop *L) {
+  BasicBlock *Header = L->getHeader();
+
+  // Push all Loop-header PHIs onto the Worklist stack, except those
+  // that are presently represented via a SCEVUnknown. SCEVUnknown for
+  // a PHI either means that it has an unrecognized structure, or it's
+  // a PHI that's in the process of being computed by createNodeForPHI.
+  // In the former case, additional loop trip count information isn't
+  // going to change anything. In the latter case, createNodeForPHI will
+  // perform the necessary updates on its own when it gets to that point.
+  SmallVector<Instruction *, 16> Worklist;
+  for (BasicBlock::iterator I = Header->begin();
+       PHINode *PN = dyn_cast<PHINode>(I); ++I) {
+    std::map<SCEVCallbackVH, SCEVHandle>::iterator It = Scalars.find((Value*)I);
+    if (It != Scalars.end() && !isa<SCEVUnknown>(It->second))
+      Worklist.push_back(PN);
+  }
+
+  while (!Worklist.empty()) {
+    Instruction *I = Worklist.pop_back_val();
+    if (Scalars.erase(I))
+      for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
+           UI != UE; ++UI)
+        Worklist.push_back(cast<Instruction>(UI));
+  }
+}
+
+/// ComputeBackedgeTakenCount - Compute the number of times the backedge
+/// of the specified loop will execute.
+ScalarEvolution::BackedgeTakenInfo
+ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
+  // If the loop doesn't have exactly one exit block, we can't analyze it.
+  SmallVector<BasicBlock*, 8> ExitBlocks;
+  L->getExitBlocks(ExitBlocks);
+  if (ExitBlocks.size() != 1) return UnknownValue;
+
+  // Okay, there is one exit block.  Try to find the condition that causes the
+  // loop to be exited.
+  BasicBlock *ExitBlock = ExitBlocks[0];
+
+  BasicBlock *ExitingBlock = 0;
+  for (pred_iterator PI = pred_begin(ExitBlock), E = pred_end(ExitBlock);
+       PI != E; ++PI)
+    if (L->contains(*PI)) {
+      if (ExitingBlock == 0)
+        ExitingBlock = *PI;
+      else
+        return UnknownValue;  // More than one block exiting!
+ } + assert(ExitingBlock && "No exits from loop, something is broken!"); + + // Okay, we've computed the exiting block. See what condition causes us to + // exit. + // + // FIXME: we should be able to handle switch instructions (with a single exit) + BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator()); + if (ExitBr == 0) return UnknownValue; + assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!"); + + // At this point, we know we have a conditional branch that determines whether + // the loop is exited. However, we don't know if the branch is executed each + // time through the loop. If not, then the execution count of the branch will + // not be equal to the trip count of the loop. + // + // Currently we check for this by checking to see if the Exit branch goes to + // the loop header. If so, we know it will always execute the same number of + // times as the loop. We also handle the case where the exit block *is* the + // loop header. This is common for un-rotated loops. More extensive analysis + // could be done to handle more cases here. + if (ExitBr->getSuccessor(0) != L->getHeader() && + ExitBr->getSuccessor(1) != L->getHeader() && + ExitBr->getParent() != L->getHeader()) + return UnknownValue; + + ICmpInst *ExitCond = dyn_cast<ICmpInst>(ExitBr->getCondition()); + + // If it's not an integer or pointer comparison then compute it the hard way. + if (ExitCond == 0) + return ComputeBackedgeTakenCountExhaustively(L, ExitBr->getCondition(), + ExitBr->getSuccessor(0) == ExitBlock); + + // If the condition was exit on true, convert the condition to exit on false + ICmpInst::Predicate Cond; + if (ExitBr->getSuccessor(1) == ExitBlock) + Cond = ExitCond->getPredicate(); + else + Cond = ExitCond->getInversePredicate(); + + // Handle common loops like: for (X = "string"; *X; ++X) + if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0))) + if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) { + SCEVHandle ItCnt = + ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond); + if (!isa<SCEVCouldNotCompute>(ItCnt)) return ItCnt; + } + + SCEVHandle LHS = getSCEV(ExitCond->getOperand(0)); + SCEVHandle RHS = getSCEV(ExitCond->getOperand(1)); + + // Try to evaluate any dependencies out of the loop. + LHS = getSCEVAtScope(LHS, L); + RHS = getSCEVAtScope(RHS, L); + + // At this point, we would like to compute how many iterations of the + // loop the predicate will return true for these inputs. + if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) { + // If there is a loop-invariant, force it into the RHS. + std::swap(LHS, RHS); + Cond = ICmpInst::getSwappedPredicate(Cond); + } + + // If we have a comparison of a chrec against a constant, try to use value + // ranges to answer this query. + if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) + if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS)) + if (AddRec->getLoop() == L) { + // Form the constant range. 
+ ConstantRange CompRange( + ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue())); + + SCEVHandle Ret = AddRec->getNumIterationsInRange(CompRange, *this); + if (!isa<SCEVCouldNotCompute>(Ret)) return Ret; + } + + switch (Cond) { + case ICmpInst::ICMP_NE: { // while (X != Y) + // Convert to: while (X-Y != 0) + SCEVHandle TC = HowFarToZero(getMinusSCEV(LHS, RHS), L); + if (!isa<SCEVCouldNotCompute>(TC)) return TC; + break; + } + case ICmpInst::ICMP_EQ: { + // Convert to: while (X-Y == 0) // while (X == Y) + SCEVHandle TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); + if (!isa<SCEVCouldNotCompute>(TC)) return TC; + break; + } + case ICmpInst::ICMP_SLT: { + BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, true); + if (BTI.hasAnyInfo()) return BTI; + break; + } + case ICmpInst::ICMP_SGT: { + BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS), + getNotSCEV(RHS), L, true); + if (BTI.hasAnyInfo()) return BTI; + break; + } + case ICmpInst::ICMP_ULT: { + BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, false); + if (BTI.hasAnyInfo()) return BTI; + break; + } + case ICmpInst::ICMP_UGT: { + BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS), + getNotSCEV(RHS), L, false); + if (BTI.hasAnyInfo()) return BTI; + break; + } + default: +#if 0 + errs() << "ComputeBackedgeTakenCount "; + if (ExitCond->getOperand(0)->getType()->isUnsigned()) + errs() << "[unsigned] "; + errs() << *LHS << " " + << Instruction::getOpcodeName(Instruction::ICmp) + << " " << *RHS << "\n"; +#endif + break; + } + return + ComputeBackedgeTakenCountExhaustively(L, ExitCond, + ExitBr->getSuccessor(0) == ExitBlock); +} + +static ConstantInt * +EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, + ScalarEvolution &SE) { + SCEVHandle InVal = SE.getConstant(C); + SCEVHandle Val = AddRec->evaluateAtIteration(InVal, SE); + assert(isa<SCEVConstant>(Val) && + "Evaluation of SCEV at constant didn't fold correctly?"); + return cast<SCEVConstant>(Val)->getValue(); +} + +/// GetAddressedElementFromGlobal - Given a global variable with an initializer +/// and a GEP expression (missing the pointer index) indexing into it, return +/// the addressed element of the initializer or null if the index expression is +/// invalid. 
+static Constant * +GetAddressedElementFromGlobal(GlobalVariable *GV, + const std::vector<ConstantInt*> &Indices) { + Constant *Init = GV->getInitializer(); + for (unsigned i = 0, e = Indices.size(); i != e; ++i) { + uint64_t Idx = Indices[i]->getZExtValue(); + if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) { + assert(Idx < CS->getNumOperands() && "Bad struct index!"); + Init = cast<Constant>(CS->getOperand(Idx)); + } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) { + if (Idx >= CA->getNumOperands()) return 0; // Bogus program + Init = cast<Constant>(CA->getOperand(Idx)); + } else if (isa<ConstantAggregateZero>(Init)) { + if (const StructType *STy = dyn_cast<StructType>(Init->getType())) { + assert(Idx < STy->getNumElements() && "Bad struct index!"); + Init = Constant::getNullValue(STy->getElementType(Idx)); + } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) { + if (Idx >= ATy->getNumElements()) return 0; // Bogus program + Init = Constant::getNullValue(ATy->getElementType()); + } else { + assert(0 && "Unknown constant aggregate type!"); + } + return 0; + } else { + return 0; // Unknown initializer type + } + } + return Init; +} + +/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of +/// 'icmp op load X, cst', try to see if we can compute the backedge +/// execution count. +SCEVHandle ScalarEvolution:: +ComputeLoadConstantCompareBackedgeTakenCount(LoadInst *LI, Constant *RHS, + const Loop *L, + ICmpInst::Predicate predicate) { + if (LI->isVolatile()) return UnknownValue; + + // Check to see if the loaded pointer is a getelementptr of a global. + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)); + if (!GEP) return UnknownValue; + + // Make sure that it is really a constant global we are gepping, with an + // initializer, and make sure the first IDX is really 0. + GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)); + if (!GV || !GV->isConstant() || !GV->hasInitializer() || + GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) || + !cast<Constant>(GEP->getOperand(1))->isNullValue()) + return UnknownValue; + + // Okay, we allow one non-constant index into the GEP instruction. + Value *VarIdx = 0; + std::vector<ConstantInt*> Indexes; + unsigned VarIdxNum = 0; + for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i) + if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) { + Indexes.push_back(CI); + } else if (!isa<ConstantInt>(GEP->getOperand(i))) { + if (VarIdx) return UnknownValue; // Multiple non-constant idx's. + VarIdx = GEP->getOperand(i); + VarIdxNum = i-2; + Indexes.push_back(0); + } + + // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant. + // Check to see if X is a loop variant variable value now. + SCEVHandle Idx = getSCEV(VarIdx); + Idx = getSCEVAtScope(Idx, L); + + // We can only recognize very limited forms of loop index expressions, in + // particular, only affine AddRec's like {C1,+,C2}. 
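ComputeLoadConstantCompareBackedgeTakenCount, begun above, brute-forces the pattern mentioned earlier for loops like for (X = "string"; *X; ++X): evaluate the affine index {C1,+,C2} at iterations 0, 1, 2, ..., look that element up in the constant initializer, and fold the comparison until it reaches the exit value. A plain-C++ imitation of that search (the string, stride and bound here are made up for illustration; the real code works on Constants and SCEVs):

#include <cstdio>

int main() {
  static const char Str[] = "string";   // the constant global initializer
  const unsigned C1 = 0, C2 = 1;        // the affine index {C1,+,C2}
  const unsigned MaxSteps = 100;        // analogue of MaxBruteForceIterations

  for (unsigned It = 0; It != MaxSteps; ++It) {
    unsigned Idx = C1 + C2 * It;        // evaluate the recurrence at iteration It
    if (Str[Idx] == 0) {                // the folded compare hits the exit value
      std::printf("terminating iteration: %u\n", It);   // prints 6
      return 0;
    }
  }
  std::printf("could not compute\n");
  return 0;
}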
+ const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx); + if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) || + !isa<SCEVConstant>(IdxExpr->getOperand(0)) || + !isa<SCEVConstant>(IdxExpr->getOperand(1))) + return UnknownValue; + + unsigned MaxSteps = MaxBruteForceIterations; + for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) { + ConstantInt *ItCst = + ConstantInt::get(IdxExpr->getType(), IterationNum); + ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this); + + // Form the GEP offset. + Indexes[VarIdxNum] = Val; + + Constant *Result = GetAddressedElementFromGlobal(GV, Indexes); + if (Result == 0) break; // Cannot compute! + + // Evaluate the condition for this iteration. + Result = ConstantExpr::getICmp(predicate, Result, RHS); + if (!isa<ConstantInt>(Result)) break; // Couldn't decide for sure + if (cast<ConstantInt>(Result)->getValue().isMinValue()) { +#if 0 + errs() << "\n***\n*** Computed loop count " << *ItCst + << "\n*** From global " << *GV << "*** BB: " << *L->getHeader() + << "***\n"; +#endif + ++NumArrayLenItCounts; + return getConstant(ItCst); // Found terminating iteration! + } + } + return UnknownValue; +} + + +/// CanConstantFold - Return true if we can constant fold an instruction of the +/// specified type, assuming that all operands were constants. +static bool CanConstantFold(const Instruction *I) { + if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || + isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I)) + return true; + + if (const CallInst *CI = dyn_cast<CallInst>(I)) + if (const Function *F = CI->getCalledFunction()) + return canConstantFoldCallTo(F); + return false; +} + +/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node +/// in the loop that V is derived from. We allow arbitrary operations along the +/// way, but the operands of an operation must either be constants or a value +/// derived from a constant PHI. If this expression does not fit with these +/// constraints, return null. +static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) { + // If this is not an instruction, or if this is an instruction outside of the + // loop, it can't be derived from a loop PHI. + Instruction *I = dyn_cast<Instruction>(V); + if (I == 0 || !L->contains(I->getParent())) return 0; + + if (PHINode *PN = dyn_cast<PHINode>(I)) { + if (L->getHeader() == I->getParent()) + return PN; + else + // We don't currently keep track of the control flow needed to evaluate + // PHIs, so we cannot handle PHIs inside of loops. + return 0; + } + + // If we won't be able to constant fold this expression even if the operands + // are constants, return early. + if (!CanConstantFold(I)) return 0; + + // Otherwise, we can evaluate this instruction if all of its operands are + // constant or derived from a PHI node themselves. + PHINode *PHI = 0; + for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op) + if (!(isa<Constant>(I->getOperand(Op)) || + isa<GlobalValue>(I->getOperand(Op)))) { + PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L); + if (P == 0) return 0; // Not evolving from PHI + if (PHI == 0) + PHI = P; + else if (PHI != P) + return 0; // Evolving from multiple different PHIs. + } + + // This is a expression evolving from a constant PHI! + return PHI; +} + +/// EvaluateExpression - Given an expression that passes the +/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node +/// in the loop has the value PHIVal. 
If we can't fold this expression for some
+/// reason, return null.
+static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
+  if (isa<PHINode>(V)) return PHIVal;
+  if (Constant *C = dyn_cast<Constant>(V)) return C;
+  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) return GV;
+  Instruction *I = cast<Instruction>(V);
+
+  std::vector<Constant*> Operands;
+  Operands.resize(I->getNumOperands());
+
+  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal);
+    if (Operands[i] == 0) return 0;
+  }
+
+  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
+    return ConstantFoldCompareInstOperands(CI->getPredicate(),
+                                           &Operands[0], Operands.size());
+  else
+    return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
+                                    &Operands[0], Operands.size());
+}
+
+/// getConstantEvolutionLoopExitValue - If we know that the specified PHI is
+/// in the header of its containing loop, that the loop executes a
+/// constant number of times, and that the PHI node is just a recurrence
+/// involving constants, fold it.
+Constant *ScalarEvolution::
+getConstantEvolutionLoopExitValue(PHINode *PN, const APInt& BEs, const Loop *L){
+  std::map<PHINode*, Constant*>::iterator I =
+    ConstantEvolutionLoopExitValue.find(PN);
+  if (I != ConstantEvolutionLoopExitValue.end())
+    return I->second;
+
+  if (BEs.ugt(APInt(BEs.getBitWidth(),MaxBruteForceIterations)))
+    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
+
+  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
+
+  // Since the loop is canonicalized, the PHI node must have two entries.  One
+  // entry must be a constant (coming in from outside of the loop), and the
+  // second must be derived from the same PHI.
+  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
+  Constant *StartCST =
+    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
+  if (StartCST == 0)
+    return RetVal = 0;  // Must be a constant.
+
+  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
+  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
+  if (PN2 != PN)
+    return RetVal = 0;  // Not derived from same PHI.
+
+  // Execute the loop symbolically to determine the exit value.
+  if (BEs.getActiveBits() >= 32)
+    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
+
+  unsigned NumIterations = BEs.getZExtValue(); // must be in range
+  unsigned IterationNum = 0;
+  for (Constant *PHIVal = StartCST; ; ++IterationNum) {
+    if (IterationNum == NumIterations)
+      return RetVal = PHIVal;  // Got exit value!
+
+    // Compute the value of the PHI node for the next iteration.
+    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
+    if (NextPHI == PHIVal)
+      return RetVal = NextPHI;  // Stopped evolving!
+    if (NextPHI == 0)
+      return 0;        // Couldn't evaluate!
+    PHIVal = NextPHI;
+  }
+}
+
+/// ComputeBackedgeTakenCountExhaustively - If the loop is known to execute a
+/// constant number of times (the condition evolves only from constants),
+/// try to evaluate a few iterations of the loop until the exit condition
+/// gets a value of ExitWhen (true or false).  If we cannot evaluate the
+/// trip count of the loop, return UnknownValue.
+SCEVHandle ScalarEvolution::
+ComputeBackedgeTakenCountExhaustively(const Loop *L, Value *Cond, bool ExitWhen) {
+  PHINode *PN = getConstantEvolvingPHI(Cond, L);
+  if (PN == 0) return UnknownValue;
+
+  // Since the loop is canonicalized, the PHI node must have two entries.
One + // entry must be a constant (coming in from outside of the loop), and the + // second must be derived from the same PHI. + bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1)); + Constant *StartCST = + dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge)); + if (StartCST == 0) return UnknownValue; // Must be a constant. + + Value *BEValue = PN->getIncomingValue(SecondIsBackedge); + PHINode *PN2 = getConstantEvolvingPHI(BEValue, L); + if (PN2 != PN) return UnknownValue; // Not derived from same PHI. + + // Okay, we find a PHI node that defines the trip count of this loop. Execute + // the loop symbolically to determine when the condition gets a value of + // "ExitWhen". + unsigned IterationNum = 0; + unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis. + for (Constant *PHIVal = StartCST; + IterationNum != MaxIterations; ++IterationNum) { + ConstantInt *CondVal = + dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal)); + + // Couldn't symbolically evaluate. + if (!CondVal) return UnknownValue; + + if (CondVal->getValue() == uint64_t(ExitWhen)) { + ConstantEvolutionLoopExitValue[PN] = PHIVal; + ++NumBruteForceTripCountsComputed; + return getConstant(ConstantInt::get(Type::Int32Ty, IterationNum)); + } + + // Compute the value of the PHI node for the next iteration. + Constant *NextPHI = EvaluateExpression(BEValue, PHIVal); + if (NextPHI == 0 || NextPHI == PHIVal) + return UnknownValue; // Couldn't evaluate or not making progress... + PHIVal = NextPHI; + } + + // Too many iterations were needed to evaluate. + return UnknownValue; +} + +/// getSCEVAtScope - Return a SCEV expression handle for the specified value +/// at the specified scope in the program. The L value specifies a loop +/// nest to evaluate the expression at, where null is the top-level or a +/// specified loop is immediately inside of the loop. +/// +/// This method can be used to compute the exit value for a variable defined +/// in a loop by querying what the value will hold in the parent loop. +/// +/// In the case that a relevant loop exit value cannot be computed, the +/// original value V is returned. +SCEVHandle ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { + // FIXME: this should be turned into a virtual method on SCEV! + + if (isa<SCEVConstant>(V)) return V; + + // If this instruction is evolved from a constant-evolving PHI, compute the + // exit value from the loop without using SCEVs. + if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) { + if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) { + const Loop *LI = (*this->LI)[I->getParent()]; + if (LI && LI->getParentLoop() == L) // Looking for loop exit value. + if (PHINode *PN = dyn_cast<PHINode>(I)) + if (PN->getParent() == LI->getHeader()) { + // Okay, there is no closed form solution for the PHI node. Check + // to see if the loop that contains it has a known backedge-taken + // count. If so, we may be able to force computation of the exit + // value. + SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(LI); + if (const SCEVConstant *BTCC = + dyn_cast<SCEVConstant>(BackedgeTakenCount)) { + // Okay, we know how many times the containing loop executes. If + // this is a constant evolving PHI node, get the final value at + // the specified iteration number. + Constant *RV = getConstantEvolutionLoopExitValue(PN, + BTCC->getValue()->getValue(), + LI); + if (RV) return getUnknown(RV); + } + } + + // Okay, this is an expression that we cannot symbolically evaluate + // into a SCEV. 
Check to see if it's possible to symbolically evaluate + // the arguments into constants, and if so, try to constant propagate the + // result. This is particularly useful for computing loop exit values. + if (CanConstantFold(I)) { + // Check to see if we've folded this instruction at this loop before. + std::map<const Loop *, Constant *> &Values = ValuesAtScopes[I]; + std::pair<std::map<const Loop *, Constant *>::iterator, bool> Pair = + Values.insert(std::make_pair(L, static_cast<Constant *>(0))); + if (!Pair.second) + return Pair.first->second ? &*getUnknown(Pair.first->second) : V; + + std::vector<Constant*> Operands; + Operands.reserve(I->getNumOperands()); + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + Value *Op = I->getOperand(i); + if (Constant *C = dyn_cast<Constant>(Op)) { + Operands.push_back(C); + } else { + // If any of the operands is non-constant and if they are + // non-integer and non-pointer, don't even try to analyze them + // with scev techniques. + if (!isSCEVable(Op->getType())) + return V; + + SCEVHandle OpV = getSCEVAtScope(getSCEV(Op), L); + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV)) { + Constant *C = SC->getValue(); + if (C->getType() != Op->getType()) + C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, + Op->getType(), + false), + C, Op->getType()); + Operands.push_back(C); + } else if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV)) { + if (Constant *C = dyn_cast<Constant>(SU->getValue())) { + if (C->getType() != Op->getType()) + C = + ConstantExpr::getCast(CastInst::getCastOpcode(C, false, + Op->getType(), + false), + C, Op->getType()); + Operands.push_back(C); + } else + return V; + } else { + return V; + } + } + } + + Constant *C; + if (const CmpInst *CI = dyn_cast<CmpInst>(I)) + C = ConstantFoldCompareInstOperands(CI->getPredicate(), + &Operands[0], Operands.size()); + else + C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), + &Operands[0], Operands.size()); + Pair.first->second = C; + return getUnknown(C); + } + } + + // This is some other type of SCEVUnknown, just return it. + return V; + } + + if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) { + // Avoid performing the look-up in the common case where the specified + // expression has no loop-variant portions. + for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { + SCEVHandle OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); + if (OpAtScope != Comm->getOperand(i)) { + // Okay, at least one of these operands is loop variant but might be + // foldable. Build a new instance of the folded commutative expression. + std::vector<SCEVHandle> NewOps(Comm->op_begin(), Comm->op_begin()+i); + NewOps.push_back(OpAtScope); + + for (++i; i != e; ++i) { + OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); + NewOps.push_back(OpAtScope); + } + if (isa<SCEVAddExpr>(Comm)) + return getAddExpr(NewOps); + if (isa<SCEVMulExpr>(Comm)) + return getMulExpr(NewOps); + if (isa<SCEVSMaxExpr>(Comm)) + return getSMaxExpr(NewOps); + if (isa<SCEVUMaxExpr>(Comm)) + return getUMaxExpr(NewOps); + assert(0 && "Unknown commutative SCEV type!"); + } + } + // If we got here, all operands are loop invariant. 
+ return Comm; + } + + if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) { + SCEVHandle LHS = getSCEVAtScope(Div->getLHS(), L); + SCEVHandle RHS = getSCEVAtScope(Div->getRHS(), L); + if (LHS == Div->getLHS() && RHS == Div->getRHS()) + return Div; // must be loop invariant + return getUDivExpr(LHS, RHS); + } + + // If this is a loop recurrence for a loop that does not contain L, then we + // are dealing with the final value computed by the loop. + if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) { + if (!L || !AddRec->getLoop()->contains(L->getHeader())) { + // To evaluate this recurrence, we need to know how many times the AddRec + // loop iterates. Compute this now. + SCEVHandle BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); + if (BackedgeTakenCount == UnknownValue) return AddRec; + + // Then, evaluate the AddRec. + return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); + } + return AddRec; + } + + if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) { + SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L); + if (Op == Cast->getOperand()) + return Cast; // must be loop invariant + return getZeroExtendExpr(Op, Cast->getType()); + } + + if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) { + SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L); + if (Op == Cast->getOperand()) + return Cast; // must be loop invariant + return getSignExtendExpr(Op, Cast->getType()); + } + + if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) { + SCEVHandle Op = getSCEVAtScope(Cast->getOperand(), L); + if (Op == Cast->getOperand()) + return Cast; // must be loop invariant + return getTruncateExpr(Op, Cast->getType()); + } + + assert(0 && "Unknown SCEV type!"); + return 0; +} + +/// getSCEVAtScope - This is a convenience function which does +/// getSCEVAtScope(getSCEV(V), L). +SCEVHandle ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { + return getSCEVAtScope(getSCEV(V), L); +} + +/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the +/// following equation: +/// +/// A * X = B (mod N) +/// +/// where N = 2^BW and BW is the common bit width of A and B. The signedness of +/// A and B isn't important. +/// +/// If the equation does not have a solution, SCEVCouldNotCompute is returned. +static SCEVHandle SolveLinEquationWithOverflow(const APInt &A, const APInt &B, + ScalarEvolution &SE) { + uint32_t BW = A.getBitWidth(); + assert(BW == B.getBitWidth() && "Bit widths must be the same."); + assert(A != 0 && "A must be non-zero."); + + // 1. D = gcd(A, N) + // + // The gcd of A and N may have only one prime factor: 2. The number of + // trailing zeros in A is its multiplicity + uint32_t Mult2 = A.countTrailingZeros(); + // D = 2^Mult2 + + // 2. Check if B is divisible by D. + // + // B is divisible by D if and only if the multiplicity of prime factor 2 for B + // is not less than multiplicity of this prime factor for D. + if (B.countTrailingZeros() < Mult2) + return SE.getCouldNotCompute(); + + // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic + // modulo (N / D). + // + // (N / D) may need BW+1 bits in its representation. Hence, we'll use this + // bit width during computations. + APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D + APInt Mod(BW + 1, 0); + Mod.set(BW - Mult2); // Mod = N / D + APInt I = AD.multiplicativeInverse(Mod); + + // 4. 
Compute the minimum unsigned root of the equation: + // I * (B / D) mod (N / D) + APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod); + + // The result is guaranteed to be less than 2^BW so we may truncate it to BW + // bits. + return SE.getConstant(Result.trunc(BW)); +} + +/// SolveQuadraticEquation - Find the roots of the quadratic equation for the +/// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which +/// might be the same) or two SCEVCouldNotCompute objects. +/// +static std::pair<SCEVHandle,SCEVHandle> +SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { + assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); + const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); + const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1)); + const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); + + // We currently can only solve this if the coefficients are constants. + if (!LC || !MC || !NC) { + const SCEV *CNC = SE.getCouldNotCompute(); + return std::make_pair(CNC, CNC); + } + + uint32_t BitWidth = LC->getValue()->getValue().getBitWidth(); + const APInt &L = LC->getValue()->getValue(); + const APInt &M = MC->getValue()->getValue(); + const APInt &N = NC->getValue()->getValue(); + APInt Two(BitWidth, 2); + APInt Four(BitWidth, 4); + + { + using namespace APIntOps; + const APInt& C = L; + // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C + // The B coefficient is M-N/2 + APInt B(M); + B -= sdiv(N,Two); + + // The A coefficient is N/2 + APInt A(N.sdiv(Two)); + + // Compute the B^2-4ac term. + APInt SqrtTerm(B); + SqrtTerm *= B; + SqrtTerm -= Four * (A * C); + + // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest + // integer value or else APInt::sqrt() will assert. + APInt SqrtVal(SqrtTerm.sqrt()); + + // Compute the two solutions for the quadratic formula. + // The divisions must be performed as signed divisions. + APInt NegB(-B); + APInt TwoA( A << 1 ); + if (TwoA.isMinValue()) { + const SCEV *CNC = SE.getCouldNotCompute(); + return std::make_pair(CNC, CNC); + } + + ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA)); + ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA)); + + return std::make_pair(SE.getConstant(Solution1), + SE.getConstant(Solution2)); + } // end APIntOps namespace +} + +/// HowFarToZero - Return the number of times a backedge comparing the specified +/// value to zero will execute. If not computable, return UnknownValue. +SCEVHandle ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) { + // If the value is a constant + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { + // If the value is already zero, the branch will execute zero times. + if (C->getValue()->isZero()) return C; + return UnknownValue; // Otherwise it will loop infinitely. + } + + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V); + if (!AddRec || AddRec->getLoop() != L) + return UnknownValue; + + if (AddRec->isAffine()) { + // If this is an affine expression, the execution count of this branch is + // the minimum unsigned root of the following equation: + // + // Start + Step*N = 0 (mod 2^BW) + // + // equivalent to: + // + // Step*N = -Start (mod 2^BW) + // + // where BW is the common bit width of Start and Step. + + // Get the initial value for the loop. 
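The affine case above reduces to Step*N = -Start (mod 2^BW), which SolveLinEquationWithOverflow (defined earlier) solves by dividing out the power of two that Step shares with the modulus and then multiplying by a modular inverse. A plain-integer walk-through with BW = 8 and made-up values (the real code performs the same steps on APInt):

#include <cstdint>
#include <cstdio>

int main() {
  const unsigned BW = 8;
  const uint32_t TwoBW = 1u << BW;            // 2^BW = 256
  uint32_t A = 6, B = 4;                      // solve 6*N == 4 (mod 256)

  // 1. D = gcd(A, 2^BW) is a power of two: 2^(number of trailing zeros of A).
  unsigned Mult2 = 0;
  while (((A >> Mult2) & 1) == 0) ++Mult2;    // Mult2 = 1, so D = 2

  // 2. B must be divisible by D or there is no solution.
  if (B % (1u << Mult2)) { std::puts("could not compute"); return 0; }

  // 3. A/D is odd, so it has an inverse modulo 2^BW / D (found by search here).
  uint32_t AD = A >> Mult2, Mod = TwoBW >> Mult2, Inv = 0;
  for (uint32_t i = 1; i < Mod; ++i)
    if ((AD * i) % Mod == 1) { Inv = i; break; }     // Inv = 43

  // 4. The minimum unsigned root is Inv * (B/D) mod (2^BW / D).
  uint32_t N = (Inv * (B >> Mult2)) % Mod;           // 86, and 6*86 == 4 (mod 256)
  std::printf("N = %u\n", N);
  return 0;
}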
+ SCEVHandle Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); + SCEVHandle Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); + + if (const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) { + // For now we handle only constant steps. + + // First, handle unitary steps. + if (StepC->getValue()->equalsInt(1)) // 1*N = -Start (mod 2^BW), so: + return getNegativeSCEV(Start); // N = -Start (as unsigned) + if (StepC->getValue()->isAllOnesValue()) // -1*N = -Start (mod 2^BW), so: + return Start; // N = Start (as unsigned) + + // Then, try to solve the above equation provided that Start is constant. + if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) + return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), + -StartC->getValue()->getValue(), + *this); + } + } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) { + // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of + // the quadratic equation to solve it. + std::pair<SCEVHandle,SCEVHandle> Roots = SolveQuadraticEquation(AddRec, + *this); + const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first); + const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second); + if (R1) { +#if 0 + errs() << "HFTZ: " << *V << " - sol#1: " << *R1 + << " sol#2: " << *R2 << "\n"; +#endif + // Pick the smallest positive root value. + if (ConstantInt *CB = + dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, + R1->getValue(), R2->getValue()))) { + if (CB->getZExtValue() == false) + std::swap(R1, R2); // R1 is the minimum root now. + + // We can only use this value if the chrec ends up with an exact zero + // value at this index. When solving for "X*X != 5", for example, we + // should not accept a root of 2. + SCEVHandle Val = AddRec->evaluateAtIteration(R1, *this); + if (Val->isZero()) + return R1; // We found a quadratic root! + } + } + } + + return UnknownValue; +} + +/// HowFarToNonZero - Return the number of times a backedge checking the +/// specified value for nonzero will execute. If not computable, return +/// UnknownValue +SCEVHandle ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { + // Loops that look like: while (X == 0) are very strange indeed. We don't + // handle them yet except for the trivial case. This could be expanded in the + // future as needed. + + // If the value is a constant, check to see if it is known to be non-zero + // already. If so, the backedge will execute zero times. + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { + if (!C->getValue()->isNullValue()) + return getIntegerSCEV(0, C->getType()); + return UnknownValue; // Otherwise it will loop infinitely. + } + + // We could implement others, but I really doubt anyone writes loops like + // this, and if they did, they would already be constant folded. + return UnknownValue; +} + +/// getLoopPredecessor - If the given loop's header has exactly one unique +/// predecessor outside the loop, return it. Otherwise return null. +/// +BasicBlock *ScalarEvolution::getLoopPredecessor(const Loop *L) { + BasicBlock *Header = L->getHeader(); + BasicBlock *Pred = 0; + for (pred_iterator PI = pred_begin(Header), E = pred_end(Header); + PI != E; ++PI) + if (!L->contains(*PI)) { + if (Pred && Pred != *PI) return 0; // Multiple predecessors. 
+ Pred = *PI; + } + return Pred; +} + +/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB +/// (which may not be an immediate predecessor) which has exactly one +/// successor from which BB is reachable, or null if no such block is +/// found. +/// +BasicBlock * +ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { + // If the block has a unique predecessor, then there is no path from the + // predecessor to the block that does not go through the direct edge + // from the predecessor to the block. + if (BasicBlock *Pred = BB->getSinglePredecessor()) + return Pred; + + // A loop's header is defined to be a block that dominates the loop. + // If the header has a unique predecessor outside the loop, it must be + // a block that has exactly one successor that can reach the loop. + if (Loop *L = LI->getLoopFor(BB)) + return getLoopPredecessor(L); + + return 0; +} + +/// isLoopGuardedByCond - Test whether entry to the loop is protected by +/// a conditional between LHS and RHS. This is used to help avoid max +/// expressions in loop trip counts. +bool ScalarEvolution::isLoopGuardedByCond(const Loop *L, + ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) { + // Interpret a null as meaning no loop, where there is obviously no guard + // (interprocedural conditions notwithstanding). + if (!L) return false; + + BasicBlock *Predecessor = getLoopPredecessor(L); + BasicBlock *PredecessorDest = L->getHeader(); + + // Starting at the loop predecessor, climb up the predecessor chain, as long + // as there are predecessors that can be found that have unique successors + // leading to the original header. + for (; Predecessor; + PredecessorDest = Predecessor, + Predecessor = getPredecessorWithUniqueSuccessorForBB(Predecessor)) { + + BranchInst *LoopEntryPredicate = + dyn_cast<BranchInst>(Predecessor->getTerminator()); + if (!LoopEntryPredicate || + LoopEntryPredicate->isUnconditional()) + continue; + + ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition()); + if (!ICI) continue; + + // Now that we found a conditional branch that dominates the loop, check to + // see if it is the comparison we are looking for. + Value *PreCondLHS = ICI->getOperand(0); + Value *PreCondRHS = ICI->getOperand(1); + ICmpInst::Predicate Cond; + if (LoopEntryPredicate->getSuccessor(0) == PredecessorDest) + Cond = ICI->getPredicate(); + else + Cond = ICI->getInversePredicate(); + + if (Cond == Pred) + ; // An exact match. + else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE) + ; // The actual condition is beyond sufficient. + else + // Check a few special cases. + switch (Cond) { + case ICmpInst::ICMP_UGT: + if (Pred == ICmpInst::ICMP_ULT) { + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_ULT; + break; + } + continue; + case ICmpInst::ICMP_SGT: + if (Pred == ICmpInst::ICMP_SLT) { + std::swap(PreCondLHS, PreCondRHS); + Cond = ICmpInst::ICMP_SLT; + break; + } + continue; + case ICmpInst::ICMP_NE: + // Expressions like (x >u 0) are often canonicalized to (x != 0), + // so check for this case by checking if the NE is comparing against + // a minimum or maximum constant. 
+ if (!ICmpInst::isTrueWhenEqual(Pred)) + if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) { + const APInt &A = CI->getValue(); + switch (Pred) { + case ICmpInst::ICMP_SLT: + if (A.isMaxSignedValue()) break; + continue; + case ICmpInst::ICMP_SGT: + if (A.isMinSignedValue()) break; + continue; + case ICmpInst::ICMP_ULT: + if (A.isMaxValue()) break; + continue; + case ICmpInst::ICMP_UGT: + if (A.isMinValue()) break; + continue; + default: + continue; + } + Cond = ICmpInst::ICMP_NE; + // NE is symmetric but the original comparison may not be. Swap + // the operands if necessary so that they match below. + if (isa<SCEVConstant>(LHS)) + std::swap(PreCondLHS, PreCondRHS); + break; + } + continue; + default: + // We weren't able to reconcile the condition. + continue; + } + + if (!PreCondLHS->getType()->isInteger()) continue; + + SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS); + SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS); + if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) || + (LHS == getNotSCEV(PreCondRHSSCEV) && + RHS == getNotSCEV(PreCondLHSSCEV))) + return true; + } + + return false; +} + +/// HowManyLessThans - Return the number of times a backedge containing the +/// specified less-than comparison will execute. If not computable, return +/// UnknownValue. +ScalarEvolution::BackedgeTakenInfo ScalarEvolution:: +HowManyLessThans(const SCEV *LHS, const SCEV *RHS, + const Loop *L, bool isSigned) { + // Only handle: "ADDREC < LoopInvariant". + if (!RHS->isLoopInvariant(L)) return UnknownValue; + + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS); + if (!AddRec || AddRec->getLoop() != L) + return UnknownValue; + + if (AddRec->isAffine()) { + // FORNOW: We only support unit strides. + unsigned BitWidth = getTypeSizeInBits(AddRec->getType()); + SCEVHandle Step = AddRec->getStepRecurrence(*this); + SCEVHandle NegOne = getIntegerSCEV(-1, AddRec->getType()); + + // TODO: handle non-constant strides. + const SCEVConstant *CStep = dyn_cast<SCEVConstant>(Step); + if (!CStep || CStep->isZero()) + return UnknownValue; + if (CStep->isOne()) { + // With unit stride, the iteration never steps past the limit value. + } else if (CStep->getValue()->getValue().isStrictlyPositive()) { + if (const SCEVConstant *CLimit = dyn_cast<SCEVConstant>(RHS)) { + // Test whether a positive iteration iteration can step past the limit + // value and past the maximum value for its type in a single step. + if (isSigned) { + APInt Max = APInt::getSignedMaxValue(BitWidth); + if ((Max - CStep->getValue()->getValue()) + .slt(CLimit->getValue()->getValue())) + return UnknownValue; + } else { + APInt Max = APInt::getMaxValue(BitWidth); + if ((Max - CStep->getValue()->getValue()) + .ult(CLimit->getValue()->getValue())) + return UnknownValue; + } + } else + // TODO: handle non-constant limit values below. + return UnknownValue; + } else + // TODO: handle negative strides below. + return UnknownValue; + + // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant + // m. So, we count the number of iterations in which {n,+,s} < m is true. + // Note that we cannot simply return max(m-n,0)/s because it's not safe to + // treat m-n as signed nor unsigned due to overflow possibility. + + // First, we get the value of the LHS in the first iteration: n + SCEVHandle Start = AddRec->getOperand(0); + + // Determine the minimum constant start value. + SCEVHandle MinStart = isa<SCEVConstant>(Start) ? Start : + getConstant(isSigned ? 
APInt::getSignedMinValue(BitWidth) : + APInt::getMinValue(BitWidth)); + + // If we know that the condition is true in order to enter the loop, + // then we know that it will run exactly (m-n)/s times. Otherwise, we + // only know that it will execute (max(m,n)-n)/s times. In both cases, + // the division must round up. + SCEVHandle End = RHS; + if (!isLoopGuardedByCond(L, + isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + getMinusSCEV(Start, Step), RHS)) + End = isSigned ? getSMaxExpr(RHS, Start) + : getUMaxExpr(RHS, Start); + + // Determine the maximum constant end value. + SCEVHandle MaxEnd = isa<SCEVConstant>(End) ? End : + getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth) : + APInt::getMaxValue(BitWidth)); + + // Finally, we subtract these two values and divide, rounding up, to get + // the number of times the backedge is executed. + SCEVHandle BECount = getUDivExpr(getAddExpr(getMinusSCEV(End, Start), + getAddExpr(Step, NegOne)), + Step); + + // The maximum backedge count is similar, except using the minimum start + // value and the maximum end value. + SCEVHandle MaxBECount = getUDivExpr(getAddExpr(getMinusSCEV(MaxEnd, + MinStart), + getAddExpr(Step, NegOne)), + Step); + + return BackedgeTakenInfo(BECount, MaxBECount); + } + + return UnknownValue; +} + +/// getNumIterationsInRange - Return the number of iterations of this loop that +/// produce values in the specified constant range. Another way of looking at +/// this is that it returns the first iteration number where the value is not in +/// the condition, thus computing the exit count. If the iteration count can't +/// be computed, an instance of SCEVCouldNotCompute is returned. +SCEVHandle SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, + ScalarEvolution &SE) const { + if (Range.isFullSet()) // Infinite loop. + return SE.getCouldNotCompute(); + + // If the start is a non-zero constant, shift the range to simplify things. + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart())) + if (!SC->getValue()->isZero()) { + std::vector<SCEVHandle> Operands(op_begin(), op_end()); + Operands[0] = SE.getIntegerSCEV(0, SC->getType()); + SCEVHandle Shifted = SE.getAddRecExpr(Operands, getLoop()); + if (const SCEVAddRecExpr *ShiftedAddRec = + dyn_cast<SCEVAddRecExpr>(Shifted)) + return ShiftedAddRec->getNumIterationsInRange( + Range.subtract(SC->getValue()->getValue()), SE); + // This is strange and shouldn't happen. + return SE.getCouldNotCompute(); + } + + // The only time we can solve this is when we have all constant indices. + // Otherwise, we cannot determine the overflow conditions. + for (unsigned i = 0, e = getNumOperands(); i != e; ++i) + if (!isa<SCEVConstant>(getOperand(i))) + return SE.getCouldNotCompute(); + + + // Okay at this point we know that all elements of the chrec are constants and + // that the start element is zero. + + // First check to see if the range contains zero. If not, the first + // iteration exits. + unsigned BitWidth = SE.getTypeSizeInBits(getType()); + if (!Range.contains(APInt(BitWidth, 0))) + return SE.getConstant(ConstantInt::get(getType(),0)); + + if (isAffine()) { + // If this is an affine expression then we have this situation: + // Solve {0,+,A} in Range === Ax in Range + + // We know that zero is in the range. If A is positive then we know that + // the upper value of the range must be the first possible exit value. + // If A is negative then the lower of the range is the last possible loop + // value. Also note that we already checked for a full range. 
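For the positive-stride affine case just described, the first iteration whose value falls outside the range works out to (End + A) divided by A, with End = Upper - 1, which is what the code below computes with APInt. A small unsigned example (wrap-around and the negative-stride case are ignored here, and the values are made up for illustration):

#include <cstdio>

int main() {
  unsigned A = 3, Upper = 10;            // chrec {0,+,3}, range [0,10)
  unsigned End = Upper - 1;              // 9: the largest value still in range
  unsigned ExitVal = (End + A) / A;      // 4: first iteration out of range
  std::printf("iteration %u -> %u (out of range), iteration %u -> %u (in range)\n",
              ExitVal, A * ExitVal, ExitVal - 1, A * (ExitVal - 1));
  return 0;
}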
+ APInt One(BitWidth,1); + APInt A = cast<SCEVConstant>(getOperand(1))->getValue()->getValue(); + APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower(); + + // The exit value should be (End+A)/A. + APInt ExitVal = (End + A).udiv(A); + ConstantInt *ExitValue = ConstantInt::get(ExitVal); + + // Evaluate at the exit value. If we really did fall out of the valid + // range, then we computed our trip count, otherwise wrap around or other + // things must have happened. + ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE); + if (Range.contains(Val->getValue())) + return SE.getCouldNotCompute(); // Something strange happened + + // Ensure that the previous value is in the range. This is a sanity check. + assert(Range.contains( + EvaluateConstantChrecAtConstant(this, + ConstantInt::get(ExitVal - One), SE)->getValue()) && + "Linear scev computation is off in a bad way!"); + return SE.getConstant(ExitValue); + } else if (isQuadratic()) { + // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the + // quadratic equation to solve it. To do this, we must frame our problem in + // terms of figuring out when zero is crossed, instead of when + // Range.getUpper() is crossed. + std::vector<SCEVHandle> NewOps(op_begin(), op_end()); + NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); + SCEVHandle NewAddRec = SE.getAddRecExpr(NewOps, getLoop()); + + // Next, solve the constructed addrec + std::pair<SCEVHandle,SCEVHandle> Roots = + SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE); + const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first); + const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second); + if (R1) { + // Pick the smallest positive root value. + if (ConstantInt *CB = + dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT, + R1->getValue(), R2->getValue()))) { + if (CB->getZExtValue() == false) + std::swap(R1, R2); // R1 is the minimum root now. + + // Make sure the root is not off by one. The returned iteration should + // not be in the range, but the previous one should be. When solving + // for "X*X < 5", for example, we should not return a root of 2. + ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this, + R1->getValue(), + SE); + if (Range.contains(R1Val->getValue())) { + // The next iteration must be out of the range... + ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1); + + R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); + if (!Range.contains(R1Val->getValue())) + return SE.getConstant(NextVal); + return SE.getCouldNotCompute(); // Something strange happened + } + + // If R1 was not in the range, then it is a good return value. Make + // sure that R1-1 WAS in the range though, just in case. 
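SolveQuadraticEquation, used in the quadratic branch above, rewrites the chrec {L,+,M,+,N} as the polynomial (N/2)*x^2 + (M - N/2)*x + L and applies the quadratic formula; the smaller root is only trusted after re-evaluating the chrec at it. A hand-checked instance in plain signed arithmetic (the real code uses APInt, signed division, and its integer square root):

#include <cmath>
#include <cstdio>

int main() {
  // The chrec {10,+,-6,+,2} evaluates at iteration x to x*x - 7*x + 10.
  long L = 10, M = -6, N = 2;
  long A = N / 2;                        // quadratic coefficient:  1
  long B = M - N / 2;                    // linear coefficient:    -7
  long C = L;                            // constant coefficient:  10

  long Disc = B * B - 4 * A * C;         // 9
  long Sqrt = (long)std::sqrt((double)Disc);
  long R1 = (-B - Sqrt) / (2 * A);       // 2
  long R2 = (-B + Sqrt) / (2 * A);       // 5

  long ValAtR1 = R1 * R1 - 7 * R1 + 10;  // 0: the smaller root really is a zero
  std::printf("roots %ld and %ld, chrec value at %ld is %ld\n",
              R1, R2, R1, ValAtR1);
  return 0;
}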
+ ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1); + R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); + if (Range.contains(R1Val->getValue())) + return R1; + return SE.getCouldNotCompute(); // Something strange happened + } + } + } + + return SE.getCouldNotCompute(); +} + + + +//===----------------------------------------------------------------------===// +// SCEVCallbackVH Class Implementation +//===----------------------------------------------------------------------===// + +void ScalarEvolution::SCEVCallbackVH::deleted() { + assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!"); + if (PHINode *PN = dyn_cast<PHINode>(getValPtr())) + SE->ConstantEvolutionLoopExitValue.erase(PN); + if (Instruction *I = dyn_cast<Instruction>(getValPtr())) + SE->ValuesAtScopes.erase(I); + SE->Scalars.erase(getValPtr()); + // this now dangles! +} + +void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) { + assert(SE && "SCEVCallbackVH called with a non-null ScalarEvolution!"); + + // Forget all the expressions associated with users of the old value, + // so that future queries will recompute the expressions using the new + // value. + SmallVector<User *, 16> Worklist; + Value *Old = getValPtr(); + bool DeleteOld = false; + for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end(); + UI != UE; ++UI) + Worklist.push_back(*UI); + while (!Worklist.empty()) { + User *U = Worklist.pop_back_val(); + // Deleting the Old value will cause this to dangle. Postpone + // that until everything else is done. + if (U == Old) { + DeleteOld = true; + continue; + } + if (PHINode *PN = dyn_cast<PHINode>(U)) + SE->ConstantEvolutionLoopExitValue.erase(PN); + if (Instruction *I = dyn_cast<Instruction>(U)) + SE->ValuesAtScopes.erase(I); + if (SE->Scalars.erase(U)) + for (Value::use_iterator UI = U->use_begin(), UE = U->use_end(); + UI != UE; ++UI) + Worklist.push_back(*UI); + } + if (DeleteOld) { + if (PHINode *PN = dyn_cast<PHINode>(Old)) + SE->ConstantEvolutionLoopExitValue.erase(PN); + if (Instruction *I = dyn_cast<Instruction>(Old)) + SE->ValuesAtScopes.erase(I); + SE->Scalars.erase(Old); + // this now dangles! + } + // this may dangle! 
+} + +ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) + : CallbackVH(V), SE(se) {} + +//===----------------------------------------------------------------------===// +// ScalarEvolution Class Implementation +//===----------------------------------------------------------------------===// + +ScalarEvolution::ScalarEvolution() + : FunctionPass(&ID), UnknownValue(new SCEVCouldNotCompute()) { +} + +bool ScalarEvolution::runOnFunction(Function &F) { + this->F = &F; + LI = &getAnalysis<LoopInfo>(); + TD = getAnalysisIfAvailable<TargetData>(); + return false; +} + +void ScalarEvolution::releaseMemory() { + Scalars.clear(); + BackedgeTakenCounts.clear(); + ConstantEvolutionLoopExitValue.clear(); + ValuesAtScopes.clear(); +} + +void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequiredTransitive<LoopInfo>(); +} + +bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { + return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L)); +} + +static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, + const Loop *L) { + // Print all inner loops first + for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) + PrintLoopInfo(OS, SE, *I); + + OS << "Loop " << L->getHeader()->getName() << ": "; + + SmallVector<BasicBlock*, 8> ExitBlocks; + L->getExitBlocks(ExitBlocks); + if (ExitBlocks.size() != 1) + OS << "<multiple exits> "; + + if (SE->hasLoopInvariantBackedgeTakenCount(L)) { + OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L); + } else { + OS << "Unpredictable backedge-taken count. "; + } + + OS << "\n"; +} + +void ScalarEvolution::print(raw_ostream &OS, const Module* ) const { + // ScalarEvolution's implementaiton of the print method is to print + // out SCEV values of all instructions that are interesting. Doing + // this potentially causes it to create new SCEV objects though, + // which technically conflicts with the const qualifier. This isn't + // observable from outside the class though (the hasSCEV function + // notwithstanding), so casting away the const isn't dangerous. + ScalarEvolution &SE = *const_cast<ScalarEvolution*>(this); + + OS << "Classifying expressions for: " << F->getName() << "\n"; + for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) + if (isSCEVable(I->getType())) { + OS << *I; + OS << " --> "; + SCEVHandle SV = SE.getSCEV(&*I); + SV->print(OS); + OS << "\t\t"; + + if (const Loop *L = LI->getLoopFor((*I).getParent())) { + OS << "Exits: "; + SCEVHandle ExitValue = SE.getSCEVAtScope(&*I, L->getParentLoop()); + if (!ExitValue->isLoopInvariant(L)) { + OS << "<<Unknown>>"; + } else { + OS << *ExitValue; + } + } + + OS << "\n"; + } + + OS << "Determining loop execution counts for: " << F->getName() << "\n"; + for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) + PrintLoopInfo(OS, &SE, *I); +} + +void ScalarEvolution::print(std::ostream &o, const Module *M) const { + raw_os_ostream OS(o); + print(OS, M); +} diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp new file mode 100644 index 0000000..7ba8268 --- /dev/null +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -0,0 +1,646 @@ +//===- ScalarEvolutionExpander.cpp - Scalar Evolution Analysis --*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation of the scalar evolution expander, +// which is used to generate the code corresponding to a given scalar evolution +// expression. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Target/TargetData.h" +using namespace llvm; + +/// InsertCastOfTo - Insert a cast of V to the specified type, doing what +/// we can to share the casts. +Value *SCEVExpander::InsertCastOfTo(Instruction::CastOps opcode, Value *V, + const Type *Ty) { + // Short-circuit unnecessary bitcasts. + if (opcode == Instruction::BitCast && V->getType() == Ty) + return V; + + // Short-circuit unnecessary inttoptr<->ptrtoint casts. + if ((opcode == Instruction::PtrToInt || opcode == Instruction::IntToPtr) && + SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(V->getType())) { + if (CastInst *CI = dyn_cast<CastInst>(V)) + if ((CI->getOpcode() == Instruction::PtrToInt || + CI->getOpcode() == Instruction::IntToPtr) && + SE.getTypeSizeInBits(CI->getType()) == + SE.getTypeSizeInBits(CI->getOperand(0)->getType())) + return CI->getOperand(0); + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) + if ((CE->getOpcode() == Instruction::PtrToInt || + CE->getOpcode() == Instruction::IntToPtr) && + SE.getTypeSizeInBits(CE->getType()) == + SE.getTypeSizeInBits(CE->getOperand(0)->getType())) + return CE->getOperand(0); + } + + // FIXME: keep track of the cast instruction. + if (Constant *C = dyn_cast<Constant>(V)) + return ConstantExpr::getCast(opcode, C, Ty); + + if (Argument *A = dyn_cast<Argument>(V)) { + // Check to see if there is already a cast! + for (Value::use_iterator UI = A->use_begin(), E = A->use_end(); + UI != E; ++UI) { + if ((*UI)->getType() == Ty) + if (CastInst *CI = dyn_cast<CastInst>(cast<Instruction>(*UI))) + if (CI->getOpcode() == opcode) { + // If the cast isn't the first instruction of the function, move it. + if (BasicBlock::iterator(CI) != + A->getParent()->getEntryBlock().begin()) { + // If the CastInst is the insert point, change the insert point. + if (CI == InsertPt) ++InsertPt; + // Splice the cast at the beginning of the entry block. + CI->moveBefore(A->getParent()->getEntryBlock().begin()); + } + return CI; + } + } + Instruction *I = CastInst::Create(opcode, V, Ty, V->getName(), + A->getParent()->getEntryBlock().begin()); + InsertedValues.insert(I); + return I; + } + + Instruction *I = cast<Instruction>(V); + + // Check to see if there is already a cast. If there is, use it. + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) { + if ((*UI)->getType() == Ty) + if (CastInst *CI = dyn_cast<CastInst>(cast<Instruction>(*UI))) + if (CI->getOpcode() == opcode) { + BasicBlock::iterator It = I; ++It; + if (isa<InvokeInst>(I)) + It = cast<InvokeInst>(I)->getNormalDest()->begin(); + while (isa<PHINode>(It)) ++It; + if (It != BasicBlock::iterator(CI)) { + // If the CastInst is the insert point, change the insert point. + if (CI == InsertPt) ++InsertPt; + // Splice the cast immediately after the operand in question. 
+ CI->moveBefore(It); + } + return CI; + } + } + BasicBlock::iterator IP = I; ++IP; + if (InvokeInst *II = dyn_cast<InvokeInst>(I)) + IP = II->getNormalDest()->begin(); + while (isa<PHINode>(IP)) ++IP; + Instruction *CI = CastInst::Create(opcode, V, Ty, V->getName(), IP); + InsertedValues.insert(CI); + return CI; +} + +/// InsertNoopCastOfTo - Insert a cast of V to the specified type, +/// which must be possible with a noop cast. +Value *SCEVExpander::InsertNoopCastOfTo(Value *V, const Type *Ty) { + Instruction::CastOps Op = CastInst::getCastOpcode(V, false, Ty, false); + assert((Op == Instruction::BitCast || + Op == Instruction::PtrToInt || + Op == Instruction::IntToPtr) && + "InsertNoopCastOfTo cannot perform non-noop casts!"); + assert(SE.getTypeSizeInBits(V->getType()) == SE.getTypeSizeInBits(Ty) && + "InsertNoopCastOfTo cannot change sizes!"); + return InsertCastOfTo(Op, V, Ty); +} + +/// InsertBinop - Insert the specified binary operator, doing a small amount +/// of work to avoid inserting an obviously redundant operation. +Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, Value *LHS, + Value *RHS, BasicBlock::iterator InsertPt) { + // Fold a binop with constant operands. + if (Constant *CLHS = dyn_cast<Constant>(LHS)) + if (Constant *CRHS = dyn_cast<Constant>(RHS)) + return ConstantExpr::get(Opcode, CLHS, CRHS); + + // Do a quick scan to see if we have this binop nearby. If so, reuse it. + unsigned ScanLimit = 6; + BasicBlock::iterator BlockBegin = InsertPt->getParent()->begin(); + if (InsertPt != BlockBegin) { + // Scanning starts from the last instruction before InsertPt. + BasicBlock::iterator IP = InsertPt; + --IP; + for (; ScanLimit; --IP, --ScanLimit) { + if (IP->getOpcode() == (unsigned)Opcode && IP->getOperand(0) == LHS && + IP->getOperand(1) == RHS) + return IP; + if (IP == BlockBegin) break; + } + } + + // If we haven't found this binop, insert it. + Instruction *BO = BinaryOperator::Create(Opcode, LHS, RHS, "tmp", InsertPt); + InsertedValues.insert(BO); + return BO; +} + +/// FactorOutConstant - Test if S is divisible by Factor, using signed +/// division. If so, update S with Factor divided out and return true. +/// S need not be evenly divisble if a reasonable remainder can be +/// computed. +/// TODO: When ScalarEvolution gets a SCEVSDivExpr, this can be made +/// unnecessary; in its place, just signed-divide Ops[i] by the scale and +/// check to see if the divide was folded. +static bool FactorOutConstant(SCEVHandle &S, + SCEVHandle &Remainder, + const APInt &Factor, + ScalarEvolution &SE) { + // Everything is divisible by one. + if (Factor == 1) + return true; + + // For a Constant, check for a multiple of the given factor. + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { + ConstantInt *CI = + ConstantInt::get(C->getValue()->getValue().sdiv(Factor)); + // If the quotient is zero and the remainder is non-zero, reject + // the value at this scale. It will be considered for subsequent + // smaller scales. + if (C->isZero() || !CI->isZero()) { + SCEVHandle Div = SE.getConstant(CI); + S = Div; + Remainder = + SE.getAddExpr(Remainder, + SE.getConstant(C->getValue()->getValue().srem(Factor))); + return true; + } + } + + // In a Mul, check if there is a constant operand which is a multiple + // of the given factor. 
+ if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0))) + if (!C->getValue()->getValue().srem(Factor)) { + std::vector<SCEVHandle> NewMulOps(M->getOperands()); + NewMulOps[0] = + SE.getConstant(C->getValue()->getValue().sdiv(Factor)); + S = SE.getMulExpr(NewMulOps); + return true; + } + + // In an AddRec, check if both start and step are divisible. + if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) { + SCEVHandle Step = A->getStepRecurrence(SE); + SCEVHandle StepRem = SE.getIntegerSCEV(0, Step->getType()); + if (!FactorOutConstant(Step, StepRem, Factor, SE)) + return false; + if (!StepRem->isZero()) + return false; + SCEVHandle Start = A->getStart(); + if (!FactorOutConstant(Start, Remainder, Factor, SE)) + return false; + S = SE.getAddRecExpr(Start, Step, A->getLoop()); + return true; + } + + return false; +} + +/// expandAddToGEP - Expand a SCEVAddExpr with a pointer type into a GEP +/// instead of using ptrtoint+arithmetic+inttoptr. This helps +/// BasicAliasAnalysis analyze the result. However, it suffers from the +/// underlying bug described in PR2831. Addition in LLVM currently always +/// has two's complement wrapping guaranteed. However, the semantics for +/// getelementptr overflow are ambiguous. In the common case though, this +/// expansion gets used when a GEP in the original code has been converted +/// into integer arithmetic, in which case the resulting code will be no +/// more undefined than it was originally. +/// +/// Design note: It might seem desirable for this function to be more +/// loop-aware. If some of the indices are loop-invariant while others +/// aren't, it might seem desirable to emit multiple GEPs, keeping the +/// loop-invariant portions of the overall computation outside the loop. +/// However, there are a few reasons this is not done here. Hoisting simple +/// arithmetic is a low-level optimization that often isn't very +/// important until late in the optimization process. In fact, passes +/// like InstructionCombining will combine GEPs, even if it means +/// pushing loop-invariant computation down into loops, so even if the +/// GEPs were split here, the work would quickly be undone. The +/// LoopStrengthReduction pass, which is usually run quite late (and +/// after the last InstructionCombining pass), takes care of hoisting +/// loop-invariant portions of expressions, after considering what +/// can be folded using target addressing modes. +/// +Value *SCEVExpander::expandAddToGEP(const SCEVHandle *op_begin, + const SCEVHandle *op_end, + const PointerType *PTy, + const Type *Ty, + Value *V) { + const Type *ElTy = PTy->getElementType(); + SmallVector<Value *, 4> GepIndices; + std::vector<SCEVHandle> Ops(op_begin, op_end); + bool AnyNonZeroIndices = false; + + // Decend down the pointer's type and attempt to convert the other + // operands into GEP indices, at each level. The first index in a GEP + // indexes into the array implied by the pointer operand; the rest of + // the indices index into the element or field type selected by the + // preceding index. + for (;;) { + APInt ElSize = APInt(SE.getTypeSizeInBits(Ty), + ElTy->isSized() ? SE.TD->getTypeAllocSize(ElTy) : 0); + std::vector<SCEVHandle> NewOps; + std::vector<SCEVHandle> ScaledOps; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + // Split AddRecs up into parts as either of the parts may be usable + // without the other. 
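+      // For example (illustrative only): {A,+,S} becomes A plus {0,+,S},
+      // so a loop-invariant start can still be turned into a GEP index
+      // even if the recurrence part cannot be scaled by the element size.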
+ if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Ops[i])) + if (!A->getStart()->isZero()) { + SCEVHandle Start = A->getStart(); + Ops.push_back(SE.getAddRecExpr(SE.getIntegerSCEV(0, A->getType()), + A->getStepRecurrence(SE), + A->getLoop())); + Ops[i] = Start; + ++e; + } + // If the scale size is not 0, attempt to factor out a scale. + if (ElSize != 0) { + SCEVHandle Op = Ops[i]; + SCEVHandle Remainder = SE.getIntegerSCEV(0, Op->getType()); + if (FactorOutConstant(Op, Remainder, ElSize, SE)) { + ScaledOps.push_back(Op); // Op now has ElSize factored out. + NewOps.push_back(Remainder); + continue; + } + } + // If the operand was not divisible, add it to the list of operands + // we'll scan next iteration. + NewOps.push_back(Ops[i]); + } + Ops = NewOps; + AnyNonZeroIndices |= !ScaledOps.empty(); + Value *Scaled = ScaledOps.empty() ? + Constant::getNullValue(Ty) : + expandCodeFor(SE.getAddExpr(ScaledOps), Ty); + GepIndices.push_back(Scaled); + + // Collect struct field index operands. + if (!Ops.empty()) + while (const StructType *STy = dyn_cast<StructType>(ElTy)) { + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[0])) + if (SE.getTypeSizeInBits(C->getType()) <= 64) { + const StructLayout &SL = *SE.TD->getStructLayout(STy); + uint64_t FullOffset = C->getValue()->getZExtValue(); + if (FullOffset < SL.getSizeInBytes()) { + unsigned ElIdx = SL.getElementContainingOffset(FullOffset); + GepIndices.push_back(ConstantInt::get(Type::Int32Ty, ElIdx)); + ElTy = STy->getTypeAtIndex(ElIdx); + Ops[0] = + SE.getConstant(ConstantInt::get(Ty, + FullOffset - + SL.getElementOffset(ElIdx))); + AnyNonZeroIndices = true; + continue; + } + } + break; + } + + if (const ArrayType *ATy = dyn_cast<ArrayType>(ElTy)) { + ElTy = ATy->getElementType(); + continue; + } + break; + } + + // If none of the operands were convertable to proper GEP indices, cast + // the base to i8* and do an ugly getelementptr with that. It's still + // better than ptrtoint+arithmetic+inttoptr at least. + if (!AnyNonZeroIndices) { + V = InsertNoopCastOfTo(V, + Type::Int8Ty->getPointerTo(PTy->getAddressSpace())); + Value *Idx = expand(SE.getAddExpr(Ops)); + Idx = InsertNoopCastOfTo(Idx, Ty); + + // Fold a GEP with constant operands. + if (Constant *CLHS = dyn_cast<Constant>(V)) + if (Constant *CRHS = dyn_cast<Constant>(Idx)) + return ConstantExpr::getGetElementPtr(CLHS, &CRHS, 1); + + // Do a quick scan to see if we have this GEP nearby. If so, reuse it. + unsigned ScanLimit = 6; + BasicBlock::iterator BlockBegin = InsertPt->getParent()->begin(); + if (InsertPt != BlockBegin) { + // Scanning starts from the last instruction before InsertPt. + BasicBlock::iterator IP = InsertPt; + --IP; + for (; ScanLimit; --IP, --ScanLimit) { + if (IP->getOpcode() == Instruction::GetElementPtr && + IP->getOperand(0) == V && IP->getOperand(1) == Idx) + return IP; + if (IP == BlockBegin) break; + } + } + + Value *GEP = GetElementPtrInst::Create(V, Idx, "scevgep", InsertPt); + InsertedValues.insert(GEP); + return GEP; + } + + // Insert a pretty getelementptr. + Value *GEP = GetElementPtrInst::Create(V, + GepIndices.begin(), + GepIndices.end(), + "scevgep", InsertPt); + Ops.push_back(SE.getUnknown(GEP)); + InsertedValues.insert(GEP); + return expand(SE.getAddExpr(Ops)); +} + +Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { + const Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Value *V = expand(S->getOperand(S->getNumOperands()-1)); + + // Turn things like ptrtoint+arithmetic+inttoptr into GEP. 
See the + // comments on expandAddToGEP for details. + if (SE.TD) + if (const PointerType *PTy = dyn_cast<PointerType>(V->getType())) { + const std::vector<SCEVHandle> &Ops = S->getOperands(); + return expandAddToGEP(&Ops[0], &Ops[Ops.size() - 1], + PTy, Ty, V); + } + + V = InsertNoopCastOfTo(V, Ty); + + // Emit a bunch of add instructions + for (int i = S->getNumOperands()-2; i >= 0; --i) { + Value *W = expand(S->getOperand(i)); + W = InsertNoopCastOfTo(W, Ty); + V = InsertBinop(Instruction::Add, V, W, InsertPt); + } + return V; +} + +Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { + const Type *Ty = SE.getEffectiveSCEVType(S->getType()); + int FirstOp = 0; // Set if we should emit a subtract. + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getOperand(0))) + if (SC->getValue()->isAllOnesValue()) + FirstOp = 1; + + int i = S->getNumOperands()-2; + Value *V = expand(S->getOperand(i+1)); + V = InsertNoopCastOfTo(V, Ty); + + // Emit a bunch of multiply instructions + for (; i >= FirstOp; --i) { + Value *W = expand(S->getOperand(i)); + W = InsertNoopCastOfTo(W, Ty); + V = InsertBinop(Instruction::Mul, V, W, InsertPt); + } + + // -1 * ... ---> 0 - ... + if (FirstOp == 1) + V = InsertBinop(Instruction::Sub, Constant::getNullValue(Ty), V, InsertPt); + return V; +} + +Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) { + const Type *Ty = SE.getEffectiveSCEVType(S->getType()); + + Value *LHS = expand(S->getLHS()); + LHS = InsertNoopCastOfTo(LHS, Ty); + if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) { + const APInt &RHS = SC->getValue()->getValue(); + if (RHS.isPowerOf2()) + return InsertBinop(Instruction::LShr, LHS, + ConstantInt::get(Ty, RHS.logBase2()), + InsertPt); + } + + Value *RHS = expand(S->getRHS()); + RHS = InsertNoopCastOfTo(RHS, Ty); + return InsertBinop(Instruction::UDiv, LHS, RHS, InsertPt); +} + +/// Move parts of Base into Rest to leave Base with the minimal +/// expression that provides a pointer operand suitable for a +/// GEP expansion. +static void ExposePointerBase(SCEVHandle &Base, SCEVHandle &Rest, + ScalarEvolution &SE) { + while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Base)) { + Base = A->getStart(); + Rest = SE.getAddExpr(Rest, + SE.getAddRecExpr(SE.getIntegerSCEV(0, A->getType()), + A->getStepRecurrence(SE), + A->getLoop())); + } + if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(Base)) { + Base = A->getOperand(A->getNumOperands()-1); + std::vector<SCEVHandle> NewAddOps(A->op_begin(), A->op_end()); + NewAddOps.back() = Rest; + Rest = SE.getAddExpr(NewAddOps); + ExposePointerBase(Base, Rest, SE); + } +} + +Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { + const Type *Ty = SE.getEffectiveSCEVType(S->getType()); + const Loop *L = S->getLoop(); + + // {X,+,F} --> X + {0,+,F} + if (!S->getStart()->isZero()) { + std::vector<SCEVHandle> NewOps(S->getOperands()); + NewOps[0] = SE.getIntegerSCEV(0, Ty); + SCEVHandle Rest = SE.getAddRecExpr(NewOps, L); + + // Turn things like ptrtoint+arithmetic+inttoptr into GEP. See the + // comments on expandAddToGEP for details. + if (SE.TD) { + SCEVHandle Base = S->getStart(); + SCEVHandle RestArray[1] = { Rest }; + // Dig into the expression to find the pointer base for a GEP. + ExposePointerBase(Base, RestArray[0], SE); + // If we found a pointer, expand the AddRec with a GEP. + if (const PointerType *PTy = dyn_cast<PointerType>(Base->getType())) { + // Make sure the Base isn't something exotic, such as a multiplied + // or divided pointer value. 
In those cases, the result type isn't + // actually a pointer type. + if (!isa<SCEVMulExpr>(Base) && !isa<SCEVUDivExpr>(Base)) { + Value *StartV = expand(Base); + assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!"); + return expandAddToGEP(RestArray, RestArray+1, PTy, Ty, StartV); + } + } + } + + Value *RestV = expand(Rest); + return expand(SE.getAddExpr(S->getStart(), SE.getUnknown(RestV))); + } + + // {0,+,1} --> Insert a canonical induction variable into the loop! + if (S->isAffine() && + S->getOperand(1) == SE.getIntegerSCEV(1, Ty)) { + // Create and insert the PHI node for the induction variable in the + // specified loop. + BasicBlock *Header = L->getHeader(); + PHINode *PN = PHINode::Create(Ty, "indvar", Header->begin()); + InsertedValues.insert(PN); + PN->addIncoming(Constant::getNullValue(Ty), L->getLoopPreheader()); + + pred_iterator HPI = pred_begin(Header); + assert(HPI != pred_end(Header) && "Loop with zero preds???"); + if (!L->contains(*HPI)) ++HPI; + assert(HPI != pred_end(Header) && L->contains(*HPI) && + "No backedge in loop?"); + + // Insert a unit add instruction right before the terminator corresponding + // to the back-edge. + Constant *One = ConstantInt::get(Ty, 1); + Instruction *Add = BinaryOperator::CreateAdd(PN, One, "indvar.next", + (*HPI)->getTerminator()); + InsertedValues.insert(Add); + + pred_iterator PI = pred_begin(Header); + if (*PI == L->getLoopPreheader()) + ++PI; + PN->addIncoming(Add, *PI); + return PN; + } + + // Get the canonical induction variable I for this loop. + Value *I = getOrInsertCanonicalInductionVariable(L, Ty); + + // If this is a simple linear addrec, emit it now as a special case. + if (S->isAffine()) { // {0,+,F} --> i*F + Value *F = expand(S->getOperand(1)); + F = InsertNoopCastOfTo(F, Ty); + + // IF the step is by one, just return the inserted IV. + if (ConstantInt *CI = dyn_cast<ConstantInt>(F)) + if (CI->getValue() == 1) + return I; + + // If the insert point is directly inside of the loop, emit the multiply at + // the insert point. Otherwise, L is a loop that is a parent of the insert + // point loop. If we can, move the multiply to the outer most loop that it + // is safe to be in. + BasicBlock::iterator MulInsertPt = getInsertionPoint(); + Loop *InsertPtLoop = SE.LI->getLoopFor(MulInsertPt->getParent()); + if (InsertPtLoop != L && InsertPtLoop && + L->contains(InsertPtLoop->getHeader())) { + do { + // If we cannot hoist the multiply out of this loop, don't. + if (!InsertPtLoop->isLoopInvariant(F)) break; + + BasicBlock *InsertPtLoopPH = InsertPtLoop->getLoopPreheader(); + + // If this loop hasn't got a preheader, we aren't able to hoist the + // multiply. + if (!InsertPtLoopPH) + break; + + // Otherwise, move the insert point to the preheader. + MulInsertPt = InsertPtLoopPH->getTerminator(); + InsertPtLoop = InsertPtLoop->getParentLoop(); + } while (InsertPtLoop != L); + } + + return InsertBinop(Instruction::Mul, I, F, MulInsertPt); + } + + // If this is a chain of recurrences, turn it into a closed form, using the + // folders, then expandCodeFor the closed form. This allows the folders to + // simplify the expression without having to build a bunch of special code + // into this folder. + SCEVHandle IH = SE.getUnknown(I); // Get I as a "symbolic" SCEV. 
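+  // For example (illustrative only): with i as the canonical induction
+  // variable, {0,+,2,+,2} evaluates at iteration i to 2*i + 2*(i*(i-1)/2),
+  // i.e. i*(i+1); the closed form is then handed back to expand().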
+ + SCEVHandle V = S->evaluateAtIteration(IH, SE); + //cerr << "Evaluated: " << *this << "\n to: " << *V << "\n"; + + return expand(V); +} + +Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) { + const Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Value *V = expand(S->getOperand()); + V = InsertNoopCastOfTo(V, SE.getEffectiveSCEVType(V->getType())); + Instruction *I = new TruncInst(V, Ty, "tmp.", InsertPt); + InsertedValues.insert(I); + return I; +} + +Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) { + const Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Value *V = expand(S->getOperand()); + V = InsertNoopCastOfTo(V, SE.getEffectiveSCEVType(V->getType())); + Instruction *I = new ZExtInst(V, Ty, "tmp.", InsertPt); + InsertedValues.insert(I); + return I; +} + +Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) { + const Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Value *V = expand(S->getOperand()); + V = InsertNoopCastOfTo(V, SE.getEffectiveSCEVType(V->getType())); + Instruction *I = new SExtInst(V, Ty, "tmp.", InsertPt); + InsertedValues.insert(I); + return I; +} + +Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { + const Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Value *LHS = expand(S->getOperand(0)); + LHS = InsertNoopCastOfTo(LHS, Ty); + for (unsigned i = 1; i < S->getNumOperands(); ++i) { + Value *RHS = expand(S->getOperand(i)); + RHS = InsertNoopCastOfTo(RHS, Ty); + Instruction *ICmp = + new ICmpInst(ICmpInst::ICMP_SGT, LHS, RHS, "tmp", InsertPt); + InsertedValues.insert(ICmp); + Instruction *Sel = SelectInst::Create(ICmp, LHS, RHS, "smax", InsertPt); + InsertedValues.insert(Sel); + LHS = Sel; + } + return LHS; +} + +Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { + const Type *Ty = SE.getEffectiveSCEVType(S->getType()); + Value *LHS = expand(S->getOperand(0)); + LHS = InsertNoopCastOfTo(LHS, Ty); + for (unsigned i = 1; i < S->getNumOperands(); ++i) { + Value *RHS = expand(S->getOperand(i)); + RHS = InsertNoopCastOfTo(RHS, Ty); + Instruction *ICmp = + new ICmpInst(ICmpInst::ICMP_UGT, LHS, RHS, "tmp", InsertPt); + InsertedValues.insert(ICmp); + Instruction *Sel = SelectInst::Create(ICmp, LHS, RHS, "umax", InsertPt); + InsertedValues.insert(Sel); + LHS = Sel; + } + return LHS; +} + +Value *SCEVExpander::expandCodeFor(SCEVHandle SH, const Type *Ty) { + // Expand the code for this SCEV. + Value *V = expand(SH); + if (Ty) { + assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) && + "non-trivial casts should be done with the SCEVs directly!"); + V = InsertNoopCastOfTo(V, Ty); + } + return V; +} + +Value *SCEVExpander::expand(const SCEV *S) { + // Check to see if we already expanded this. + std::map<SCEVHandle, AssertingVH<Value> >::iterator I = + InsertedExpressions.find(S); + if (I != InsertedExpressions.end()) + return I->second; + + Value *V = visit(S); + InsertedExpressions[S] = V; + return V; +} diff --git a/lib/Analysis/SparsePropagation.cpp b/lib/Analysis/SparsePropagation.cpp new file mode 100644 index 0000000..5433068 --- /dev/null +++ b/lib/Analysis/SparsePropagation.cpp @@ -0,0 +1,331 @@ +//===- SparsePropagation.cpp - Sparse Conditional Property Propagation ----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file implements an abstract sparse conditional propagation algorithm, +// modeled after SCCP, but with a customizable lattice function. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "sparseprop" +#include "llvm/Analysis/SparsePropagation.h" +#include "llvm/Constants.h" +#include "llvm/Function.h" +#include "llvm/Instructions.h" +#include "llvm/Support/Debug.h" +using namespace llvm; + +//===----------------------------------------------------------------------===// +// AbstractLatticeFunction Implementation +//===----------------------------------------------------------------------===// + +AbstractLatticeFunction::~AbstractLatticeFunction() {} + +/// PrintValue - Render the specified lattice value to the specified stream. +void AbstractLatticeFunction::PrintValue(LatticeVal V, std::ostream &OS) { + if (V == UndefVal) + OS << "undefined"; + else if (V == OverdefinedVal) + OS << "overdefined"; + else if (V == UntrackedVal) + OS << "untracked"; + else + OS << "unknown lattice value"; +} + +//===----------------------------------------------------------------------===// +// SparseSolver Implementation +//===----------------------------------------------------------------------===// + +/// getOrInitValueState - Return the LatticeVal object that corresponds to the +/// value, initializing the value's state if it hasn't been entered into the +/// map yet. This function is necessary because not all values should start +/// out in the underdefined state... Arguments should be overdefined, and +/// constants should be marked as constants. +/// +SparseSolver::LatticeVal SparseSolver::getOrInitValueState(Value *V) { + DenseMap<Value*, LatticeVal>::iterator I = ValueState.find(V); + if (I != ValueState.end()) return I->second; // Common case, in the map + + LatticeVal LV; + if (LatticeFunc->IsUntrackedValue(V)) + return LatticeFunc->getUntrackedVal(); + else if (Constant *C = dyn_cast<Constant>(V)) + LV = LatticeFunc->ComputeConstant(C); + else if (Argument *A = dyn_cast<Argument>(V)) + LV = LatticeFunc->ComputeArgument(A); + else if (!isa<Instruction>(V)) + // All other non-instructions are overdefined. + LV = LatticeFunc->getOverdefinedVal(); + else + // All instructions are underdefined by default. + LV = LatticeFunc->getUndefVal(); + + // If this value is untracked, don't add it to the map. + if (LV == LatticeFunc->getUntrackedVal()) + return LV; + return ValueState[V] = LV; +} + +/// UpdateState - When the state for some instruction is potentially updated, +/// this function notices and adds I to the worklist if needed. +void SparseSolver::UpdateState(Instruction &Inst, LatticeVal V) { + DenseMap<Value*, LatticeVal>::iterator I = ValueState.find(&Inst); + if (I != ValueState.end() && I->second == V) + return; // No change. + + // An update. Visit uses of I. + ValueState[&Inst] = V; + InstWorkList.push_back(&Inst); +} + +/// MarkBlockExecutable - This method can be used by clients to mark all of +/// the blocks that are known to be intrinsically live in the processed unit. +void SparseSolver::MarkBlockExecutable(BasicBlock *BB) { + DOUT << "Marking Block Executable: " << BB->getNameStart() << "\n"; + BBExecutable.insert(BB); // Basic block is executable! + BBWorkList.push_back(BB); // Add the block to the work list! 
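+  // (Solve() seeds this with the function's entry block; clients may also
+  // call it up front for any other blocks they know to be live.)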
+} + +/// markEdgeExecutable - Mark a basic block as executable, adding it to the BB +/// work list if it is not already executable... +void SparseSolver::markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest) { + if (!KnownFeasibleEdges.insert(Edge(Source, Dest)).second) + return; // This edge is already known to be executable! + + DOUT << "Marking Edge Executable: " << Source->getNameStart() + << " -> " << Dest->getNameStart() << "\n"; + + if (BBExecutable.count(Dest)) { + // The destination is already executable, but we just made an edge + // feasible that wasn't before. Revisit the PHI nodes in the block + // because they have potentially new operands. + for (BasicBlock::iterator I = Dest->begin(); isa<PHINode>(I); ++I) + visitPHINode(*cast<PHINode>(I)); + + } else { + MarkBlockExecutable(Dest); + } +} + + +/// getFeasibleSuccessors - Return a vector of booleans to indicate which +/// successors are reachable from a given terminator instruction. +void SparseSolver::getFeasibleSuccessors(TerminatorInst &TI, + SmallVectorImpl<bool> &Succs, + bool AggressiveUndef) { + Succs.resize(TI.getNumSuccessors()); + if (TI.getNumSuccessors() == 0) return; + + if (BranchInst *BI = dyn_cast<BranchInst>(&TI)) { + if (BI->isUnconditional()) { + Succs[0] = true; + return; + } + + LatticeVal BCValue; + if (AggressiveUndef) + BCValue = getOrInitValueState(BI->getCondition()); + else + BCValue = getLatticeState(BI->getCondition()); + + if (BCValue == LatticeFunc->getOverdefinedVal() || + BCValue == LatticeFunc->getUntrackedVal()) { + // Overdefined condition variables can branch either way. + Succs[0] = Succs[1] = true; + return; + } + + // If undefined, neither is feasible yet. + if (BCValue == LatticeFunc->getUndefVal()) + return; + + Constant *C = LatticeFunc->GetConstant(BCValue, BI->getCondition(), *this); + if (C == 0 || !isa<ConstantInt>(C)) { + // Non-constant values can go either way. + Succs[0] = Succs[1] = true; + return; + } + + // Constant condition variables mean the branch can only go a single way + Succs[C == ConstantInt::getFalse()] = true; + return; + } + + if (isa<InvokeInst>(TI)) { + // Invoke instructions successors are always executable. + // TODO: Could ask the lattice function if the value can throw. + Succs[0] = Succs[1] = true; + return; + } + + SwitchInst &SI = cast<SwitchInst>(TI); + LatticeVal SCValue; + if (AggressiveUndef) + SCValue = getOrInitValueState(SI.getCondition()); + else + SCValue = getLatticeState(SI.getCondition()); + + if (SCValue == LatticeFunc->getOverdefinedVal() || + SCValue == LatticeFunc->getUntrackedVal()) { + // All destinations are executable! + Succs.assign(TI.getNumSuccessors(), true); + return; + } + + // If undefined, neither is feasible yet. + if (SCValue == LatticeFunc->getUndefVal()) + return; + + Constant *C = LatticeFunc->GetConstant(SCValue, SI.getCondition(), *this); + if (C == 0 || !isa<ConstantInt>(C)) { + // All destinations are executable! + Succs.assign(TI.getNumSuccessors(), true); + return; + } + + Succs[SI.findCaseValue(cast<ConstantInt>(C))] = true; +} + + +/// isEdgeFeasible - Return true if the control flow edge from the 'From' +/// basic block to the 'To' basic block is currently feasible... 
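+/// For example (illustrative only; names made up): for "br i1 %c, label %t,
+/// label %f", the edge to %t stays infeasible while %c is still undefined
+/// or is known to be the constant false.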
+bool SparseSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To, + bool AggressiveUndef) { + SmallVector<bool, 16> SuccFeasible; + TerminatorInst *TI = From->getTerminator(); + getFeasibleSuccessors(*TI, SuccFeasible, AggressiveUndef); + + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) + if (TI->getSuccessor(i) == To && SuccFeasible[i]) + return true; + + return false; +} + +void SparseSolver::visitTerminatorInst(TerminatorInst &TI) { + SmallVector<bool, 16> SuccFeasible; + getFeasibleSuccessors(TI, SuccFeasible, true); + + BasicBlock *BB = TI.getParent(); + + // Mark all feasible successors executable... + for (unsigned i = 0, e = SuccFeasible.size(); i != e; ++i) + if (SuccFeasible[i]) + markEdgeExecutable(BB, TI.getSuccessor(i)); +} + +void SparseSolver::visitPHINode(PHINode &PN) { + LatticeVal PNIV = getOrInitValueState(&PN); + LatticeVal Overdefined = LatticeFunc->getOverdefinedVal(); + + // If this value is already overdefined (common) just return. + if (PNIV == Overdefined || PNIV == LatticeFunc->getUntrackedVal()) + return; // Quick exit + + // Super-extra-high-degree PHI nodes are unlikely to ever be interesting, + // and slow us down a lot. Just mark them overdefined. + if (PN.getNumIncomingValues() > 64) { + UpdateState(PN, Overdefined); + return; + } + + // Look at all of the executable operands of the PHI node. If any of them + // are overdefined, the PHI becomes overdefined as well. Otherwise, ask the + // transfer function to give us the merge of the incoming values. + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { + // If the edge is not yet known to be feasible, it doesn't impact the PHI. + if (!isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent(), true)) + continue; + + // Merge in this value. + LatticeVal OpVal = getOrInitValueState(PN.getIncomingValue(i)); + if (OpVal != PNIV) + PNIV = LatticeFunc->MergeValues(PNIV, OpVal); + + if (PNIV == Overdefined) + break; // Rest of input values don't matter. + } + + // Update the PHI with the compute value, which is the merge of the inputs. + UpdateState(PN, PNIV); +} + + +void SparseSolver::visitInst(Instruction &I) { + // PHIs are handled by the propagation logic, they are never passed into the + // transfer functions. + if (PHINode *PN = dyn_cast<PHINode>(&I)) + return visitPHINode(*PN); + + // Otherwise, ask the transfer function what the result is. If this is + // something that we care about, remember it. + LatticeVal IV = LatticeFunc->ComputeInstructionState(I, *this); + if (IV != LatticeFunc->getUntrackedVal()) + UpdateState(I, IV); + + if (TerminatorInst *TI = dyn_cast<TerminatorInst>(&I)) + visitTerminatorInst(*TI); +} + +void SparseSolver::Solve(Function &F) { + MarkBlockExecutable(&F.getEntryBlock()); + + // Process the work lists until they are empty! + while (!BBWorkList.empty() || !InstWorkList.empty()) { + // Process the instruction work list. + while (!InstWorkList.empty()) { + Instruction *I = InstWorkList.back(); + InstWorkList.pop_back(); + + DOUT << "\nPopped off I-WL: " << *I; + + // "I" got into the work list because it made a transition. See if any + // users are both live and in need of updating. + for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); + UI != E; ++UI) { + Instruction *U = cast<Instruction>(*UI); + if (BBExecutable.count(U->getParent())) // Inst is executable? + visitInst(*U); + } + } + + // Process the basic block work list. 
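+    // (Each block popped here has all of its instructions visited once;
+    // later changes to their operands re-trigger visits through the
+    // instruction work list above.)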
+ while (!BBWorkList.empty()) { + BasicBlock *BB = BBWorkList.back(); + BBWorkList.pop_back(); + + DOUT << "\nPopped off BBWL: " << *BB; + + // Notify all instructions in this basic block that they are newly + // executable. + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) + visitInst(*I); + } + } +} + +void SparseSolver::Print(Function &F, std::ostream &OS) const { + OS << "\nFUNCTION: " << F.getNameStr() << "\n"; + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { + if (!BBExecutable.count(BB)) + OS << "INFEASIBLE: "; + OS << "\t"; + if (BB->hasName()) + OS << BB->getNameStr() << ":\n"; + else + OS << "; anon bb\n"; + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { + LatticeFunc->PrintValue(getLatticeState(I), OS); + OS << *I; + } + + OS << "\n"; + } +} + diff --git a/lib/Analysis/Trace.cpp b/lib/Analysis/Trace.cpp new file mode 100644 index 0000000..8f19fda --- /dev/null +++ b/lib/Analysis/Trace.cpp @@ -0,0 +1,50 @@ +//===- Trace.cpp - Implementation of Trace class --------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This class represents a single trace of LLVM basic blocks. A trace is a +// single entry, multiple exit, region of code that is often hot. Trace-based +// optimizations treat traces almost like they are a large, strange, basic +// block: because the trace path is assumed to be hot, optimizations for the +// fall-through path are made at the expense of the non-fall-through paths. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Trace.h" +#include "llvm/Function.h" +#include "llvm/Assembly/Writer.h" +#include "llvm/Support/Streams.h" +using namespace llvm; + +Function *Trace::getFunction() const { + return getEntryBasicBlock()->getParent(); +} + +Module *Trace::getModule() const { + return getFunction()->getParent(); +} + +/// print - Write trace to output stream. +/// +void Trace::print(std::ostream &O) const { + Function *F = getFunction (); + O << "; Trace from function " << F->getName() << ", blocks:\n"; + for (const_iterator i = begin(), e = end(); i != e; ++i) { + O << "; "; + WriteAsOperand(O, *i, true, getModule()); + O << "\n"; + } + O << "; Trace parent function: \n" << *F; +} + +/// dump - Debugger convenience method; writes trace to standard error +/// output stream. +/// +void Trace::dump() const { + print(cerr); +} diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp new file mode 100644 index 0000000..29ff8aa --- /dev/null +++ b/lib/Analysis/ValueTracking.cpp @@ -0,0 +1,1079 @@ +//===- ValueTracking.cpp - Walk computations to compute properties --------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains routines that help analyze properties that chains of +// computations have. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Constants.h" +#include "llvm/Instructions.h" +#include "llvm/GlobalVariable.h" +#include "llvm/IntrinsicInst.h" +#include "llvm/Target/TargetData.h" +#include "llvm/Support/GetElementPtrTypeIterator.h" +#include "llvm/Support/MathExtras.h" +#include <cstring> +using namespace llvm; + +/// getOpcode - If this is an Instruction or a ConstantExpr, return the +/// opcode value. Otherwise return UserOp1. +static unsigned getOpcode(const Value *V) { + if (const Instruction *I = dyn_cast<Instruction>(V)) + return I->getOpcode(); + if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) + return CE->getOpcode(); + // Use UserOp1 to mean there's no opcode. + return Instruction::UserOp1; +} + + +/// ComputeMaskedBits - Determine which of the bits specified in Mask are +/// known to be either zero or one and return them in the KnownZero/KnownOne +/// bit sets. This code only analyzes bits in Mask, in order to short-circuit +/// processing. +/// NOTE: we cannot consider 'undef' to be "IsZero" here. The problem is that +/// we cannot optimize based on the assumption that it is zero without changing +/// it to be an explicit zero. If we don't change it to zero, other code could +/// optimized based on the contradictory assumption that it is non-zero. +/// Because instcombine aggressively folds operations with undef args anyway, +/// this won't lose us code quality. +void llvm::ComputeMaskedBits(Value *V, const APInt &Mask, + APInt &KnownZero, APInt &KnownOne, + TargetData *TD, unsigned Depth) { + const unsigned MaxDepth = 6; + assert(V && "No Value?"); + assert(Depth <= MaxDepth && "Limit Search Depth"); + unsigned BitWidth = Mask.getBitWidth(); + assert((V->getType()->isInteger() || isa<PointerType>(V->getType())) && + "Not integer or pointer type!"); + assert((!TD || TD->getTypeSizeInBits(V->getType()) == BitWidth) && + (!isa<IntegerType>(V->getType()) || + V->getType()->getPrimitiveSizeInBits() == BitWidth) && + KnownZero.getBitWidth() == BitWidth && + KnownOne.getBitWidth() == BitWidth && + "V, Mask, KnownOne and KnownZero should have same BitWidth"); + + if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { + // We know all of the bits for a constant! + KnownOne = CI->getValue() & Mask; + KnownZero = ~KnownOne & Mask; + return; + } + // Null is all-zeros. + if (isa<ConstantPointerNull>(V)) { + KnownOne.clear(); + KnownZero = Mask; + return; + } + // The address of an aligned GlobalValue has trailing zeros. + if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) { + unsigned Align = GV->getAlignment(); + if (Align == 0 && TD && GV->getType()->getElementType()->isSized()) + Align = TD->getPrefTypeAlignment(GV->getType()->getElementType()); + if (Align > 0) + KnownZero = Mask & APInt::getLowBitsSet(BitWidth, + CountTrailingZeros_32(Align)); + else + KnownZero.clear(); + KnownOne.clear(); + return; + } + + KnownZero.clear(); KnownOne.clear(); // Start out not knowing anything. + + if (Depth == MaxDepth || Mask == 0) + return; // Limit search depth. + + User *I = dyn_cast<User>(V); + if (!I) return; + + APInt KnownZero2(KnownZero), KnownOne2(KnownOne); + switch (getOpcode(I)) { + default: break; + case Instruction::And: { + // If either the LHS or the RHS are Zero, the result is zero. 
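+    // Worked example (illustrative only; %x is made up): for
+    // "and i32 %x, 65280" the constant operand contributes known-zero bits
+    // everywhere outside bits 8-15, so Mask2 only asks about those bits of
+    // %x, and every other bit of the result is already known to be zero.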
+ ComputeMaskedBits(I->getOperand(1), Mask, KnownZero, KnownOne, TD, Depth+1); + APInt Mask2(Mask & ~KnownZero); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero2, KnownOne2, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + + // Output known-1 bits are only known if set in both the LHS & RHS. + KnownOne &= KnownOne2; + // Output known-0 are known to be clear if zero in either the LHS | RHS. + KnownZero |= KnownZero2; + return; + } + case Instruction::Or: { + ComputeMaskedBits(I->getOperand(1), Mask, KnownZero, KnownOne, TD, Depth+1); + APInt Mask2(Mask & ~KnownOne); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero2, KnownOne2, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + + // Output known-0 bits are only known if clear in both the LHS & RHS. + KnownZero &= KnownZero2; + // Output known-1 are known to be set if set in either the LHS | RHS. + KnownOne |= KnownOne2; + return; + } + case Instruction::Xor: { + ComputeMaskedBits(I->getOperand(1), Mask, KnownZero, KnownOne, TD, Depth+1); + ComputeMaskedBits(I->getOperand(0), Mask, KnownZero2, KnownOne2, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + + // Output known-0 bits are known if clear or set in both the LHS & RHS. + APInt KnownZeroOut = (KnownZero & KnownZero2) | (KnownOne & KnownOne2); + // Output known-1 are known to be set if set in only one of the LHS, RHS. + KnownOne = (KnownZero & KnownOne2) | (KnownOne & KnownZero2); + KnownZero = KnownZeroOut; + return; + } + case Instruction::Mul: { + APInt Mask2 = APInt::getAllOnesValue(BitWidth); + ComputeMaskedBits(I->getOperand(1), Mask2, KnownZero, KnownOne, TD,Depth+1); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero2, KnownOne2, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + + // If low bits are zero in either operand, output low known-0 bits. + // Also compute a conserative estimate for high known-0 bits. + // More trickiness is possible, but this is sufficient for the + // interesting case of alignment computation. + KnownOne.clear(); + unsigned TrailZ = KnownZero.countTrailingOnes() + + KnownZero2.countTrailingOnes(); + unsigned LeadZ = std::max(KnownZero.countLeadingOnes() + + KnownZero2.countLeadingOnes(), + BitWidth) - BitWidth; + + TrailZ = std::min(TrailZ, BitWidth); + LeadZ = std::min(LeadZ, BitWidth); + KnownZero = APInt::getLowBitsSet(BitWidth, TrailZ) | + APInt::getHighBitsSet(BitWidth, LeadZ); + KnownZero &= Mask; + return; + } + case Instruction::UDiv: { + // For the purposes of computing leading zeros we can conservatively + // treat a udiv as a logical right shift by the power of 2 known to + // be less than the denominator. 
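+    // For example (illustrative only, i32): if the numerator is known to
+    // have 8 leading zero bits and the denominator is known to have bit 4
+    // set (so it is at least 16), the quotient is less than 2^24 / 16 and
+    // therefore has at least 12 leading zero bits.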
+ APInt AllOnes = APInt::getAllOnesValue(BitWidth); + ComputeMaskedBits(I->getOperand(0), + AllOnes, KnownZero2, KnownOne2, TD, Depth+1); + unsigned LeadZ = KnownZero2.countLeadingOnes(); + + KnownOne2.clear(); + KnownZero2.clear(); + ComputeMaskedBits(I->getOperand(1), + AllOnes, KnownZero2, KnownOne2, TD, Depth+1); + unsigned RHSUnknownLeadingOnes = KnownOne2.countLeadingZeros(); + if (RHSUnknownLeadingOnes != BitWidth) + LeadZ = std::min(BitWidth, + LeadZ + BitWidth - RHSUnknownLeadingOnes - 1); + + KnownZero = APInt::getHighBitsSet(BitWidth, LeadZ) & Mask; + return; + } + case Instruction::Select: + ComputeMaskedBits(I->getOperand(2), Mask, KnownZero, KnownOne, TD, Depth+1); + ComputeMaskedBits(I->getOperand(1), Mask, KnownZero2, KnownOne2, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + + // Only known if known in both the LHS and RHS. + KnownOne &= KnownOne2; + KnownZero &= KnownZero2; + return; + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::SIToFP: + case Instruction::UIToFP: + return; // Can't work with floating point. + case Instruction::PtrToInt: + case Instruction::IntToPtr: + // We can't handle these if we don't know the pointer size. + if (!TD) return; + // FALL THROUGH and handle them the same as zext/trunc. + case Instruction::ZExt: + case Instruction::Trunc: { + // Note that we handle pointer operands here because of inttoptr/ptrtoint + // which fall through here. + const Type *SrcTy = I->getOperand(0)->getType(); + unsigned SrcBitWidth = TD ? + TD->getTypeSizeInBits(SrcTy) : + SrcTy->getPrimitiveSizeInBits(); + APInt MaskIn(Mask); + MaskIn.zextOrTrunc(SrcBitWidth); + KnownZero.zextOrTrunc(SrcBitWidth); + KnownOne.zextOrTrunc(SrcBitWidth); + ComputeMaskedBits(I->getOperand(0), MaskIn, KnownZero, KnownOne, TD, + Depth+1); + KnownZero.zextOrTrunc(BitWidth); + KnownOne.zextOrTrunc(BitWidth); + // Any top bits are known to be zero. + if (BitWidth > SrcBitWidth) + KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + return; + } + case Instruction::BitCast: { + const Type *SrcTy = I->getOperand(0)->getType(); + if (SrcTy->isInteger() || isa<PointerType>(SrcTy)) { + ComputeMaskedBits(I->getOperand(0), Mask, KnownZero, KnownOne, TD, + Depth+1); + return; + } + break; + } + case Instruction::SExt: { + // Compute the bits in the result that are not present in the input. + const IntegerType *SrcTy = cast<IntegerType>(I->getOperand(0)->getType()); + unsigned SrcBitWidth = SrcTy->getBitWidth(); + + APInt MaskIn(Mask); + MaskIn.trunc(SrcBitWidth); + KnownZero.trunc(SrcBitWidth); + KnownOne.trunc(SrcBitWidth); + ComputeMaskedBits(I->getOperand(0), MaskIn, KnownZero, KnownOne, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + KnownZero.zext(BitWidth); + KnownOne.zext(BitWidth); + + // If the sign bit of the input is known set or clear, then we know the + // top bits of the result. 
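+    // For example (illustrative only): when sign-extending an i8 whose bit 7
+    // is known to be zero to an i32, bits 8-31 of the result are known zero
+    // as well.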
+ if (KnownZero[SrcBitWidth-1]) // Input sign bit known zero + KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + else if (KnownOne[SrcBitWidth-1]) // Input sign bit known set + KnownOne |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + return; + } + case Instruction::Shl: + // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 + if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + APInt Mask2(Mask.lshr(ShiftAmt)); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero, KnownOne, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + KnownZero <<= ShiftAmt; + KnownOne <<= ShiftAmt; + KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); // low bits known 0 + return; + } + break; + case Instruction::LShr: + // (ushr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 + if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { + // Compute the new bits that are at the top now. + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + + // Unsigned shift right. + APInt Mask2(Mask.shl(ShiftAmt)); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero,KnownOne, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); + KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); + // high bits known zero. + KnownZero |= APInt::getHighBitsSet(BitWidth, ShiftAmt); + return; + } + break; + case Instruction::AShr: + // (ashr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 + if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { + // Compute the new bits that are at the top now. + uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); + + // Signed shift right. + APInt Mask2(Mask.shl(ShiftAmt)); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero, KnownOne, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); + KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); + + APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); + if (KnownZero[BitWidth-ShiftAmt-1]) // New bits are known zero. + KnownZero |= HighBits; + else if (KnownOne[BitWidth-ShiftAmt-1]) // New bits are known one. + KnownOne |= HighBits; + return; + } + break; + case Instruction::Sub: { + if (ConstantInt *CLHS = dyn_cast<ConstantInt>(I->getOperand(0))) { + // We know that the top bits of C-X are clear if X contains less bits + // than C (i.e. no wrap-around can happen). For example, 20-X is + // positive if we can prove that X is >= 0 and < 16. + if (!CLHS->getValue().isNegative()) { + unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros(); + // NLZ can't be BitWidth with no sign bit + APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1); + ComputeMaskedBits(I->getOperand(1), MaskV, KnownZero2, KnownOne2, + TD, Depth+1); + + // If all of the MaskV bits are known to be zero, then we know the + // output top bits are zero, because we now know that the output is + // from [0-C]. + if ((KnownZero2 & MaskV) == MaskV) { + unsigned NLZ2 = CLHS->getValue().countLeadingZeros(); + // Top bits known zero. + KnownZero = APInt::getHighBitsSet(BitWidth, NLZ2) & Mask; + } + } + } + } + // fall through + case Instruction::Add: { + // If one of the operands has trailing zeros, than the bits that the + // other operand has in those bit positions will be preserved in the + // result. For an add, this works with either operand. 
For a subtract, + // this only works if the known zeros are in the right operand. + APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); + APInt Mask2 = APInt::getLowBitsSet(BitWidth, + BitWidth - Mask.countLeadingZeros()); + ComputeMaskedBits(I->getOperand(0), Mask2, LHSKnownZero, LHSKnownOne, TD, + Depth+1); + assert((LHSKnownZero & LHSKnownOne) == 0 && + "Bits known to be one AND zero?"); + unsigned LHSKnownZeroOut = LHSKnownZero.countTrailingOnes(); + + ComputeMaskedBits(I->getOperand(1), Mask2, KnownZero2, KnownOne2, TD, + Depth+1); + assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); + unsigned RHSKnownZeroOut = KnownZero2.countTrailingOnes(); + + // Determine which operand has more trailing zeros, and use that + // many bits from the other operand. + if (LHSKnownZeroOut > RHSKnownZeroOut) { + if (getOpcode(I) == Instruction::Add) { + APInt Mask = APInt::getLowBitsSet(BitWidth, LHSKnownZeroOut); + KnownZero |= KnownZero2 & Mask; + KnownOne |= KnownOne2 & Mask; + } else { + // If the known zeros are in the left operand for a subtract, + // fall back to the minimum known zeros in both operands. + KnownZero |= APInt::getLowBitsSet(BitWidth, + std::min(LHSKnownZeroOut, + RHSKnownZeroOut)); + } + } else if (RHSKnownZeroOut >= LHSKnownZeroOut) { + APInt Mask = APInt::getLowBitsSet(BitWidth, RHSKnownZeroOut); + KnownZero |= LHSKnownZero & Mask; + KnownOne |= LHSKnownOne & Mask; + } + return; + } + case Instruction::SRem: + if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { + APInt RA = Rem->getValue(); + if (RA.isPowerOf2() || (-RA).isPowerOf2()) { + APInt LowBits = RA.isStrictlyPositive() ? (RA - 1) : ~RA; + APInt Mask2 = LowBits | APInt::getSignBit(BitWidth); + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero2, KnownOne2, TD, + Depth+1); + + // If the sign bit of the first operand is zero, the sign bit of + // the result is zero. If the first operand has no one bits below + // the second operand's single 1 bit, its sign will be zero. + if (KnownZero2[BitWidth-1] || ((KnownZero2 & LowBits) == LowBits)) + KnownZero2 |= ~LowBits; + + KnownZero |= KnownZero2 & Mask; + + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + } + } + break; + case Instruction::URem: { + if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { + APInt RA = Rem->getValue(); + if (RA.isPowerOf2()) { + APInt LowBits = (RA - 1); + APInt Mask2 = LowBits & Mask; + KnownZero |= ~LowBits & Mask; + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero, KnownOne, TD, + Depth+1); + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + break; + } + } + + // Since the result is less than or equal to either operand, any leading + // zero bits in either operand must also exist in the result. 
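+    // For example (illustrative only): if either i32 operand of the urem is
+    // known to be less than 2^16, the result is less than 2^16 as well, so
+    // its top 16 bits are known zero.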
+ APInt AllOnes = APInt::getAllOnesValue(BitWidth); + ComputeMaskedBits(I->getOperand(0), AllOnes, KnownZero, KnownOne, + TD, Depth+1); + ComputeMaskedBits(I->getOperand(1), AllOnes, KnownZero2, KnownOne2, + TD, Depth+1); + + unsigned Leaders = std::max(KnownZero.countLeadingOnes(), + KnownZero2.countLeadingOnes()); + KnownOne.clear(); + KnownZero = APInt::getHighBitsSet(BitWidth, Leaders) & Mask; + break; + } + + case Instruction::Alloca: + case Instruction::Malloc: { + AllocationInst *AI = cast<AllocationInst>(V); + unsigned Align = AI->getAlignment(); + if (Align == 0 && TD) { + if (isa<AllocaInst>(AI)) + Align = TD->getABITypeAlignment(AI->getType()->getElementType()); + else if (isa<MallocInst>(AI)) { + // Malloc returns maximally aligned memory. + Align = TD->getABITypeAlignment(AI->getType()->getElementType()); + Align = + std::max(Align, + (unsigned)TD->getABITypeAlignment(Type::DoubleTy)); + Align = + std::max(Align, + (unsigned)TD->getABITypeAlignment(Type::Int64Ty)); + } + } + + if (Align > 0) + KnownZero = Mask & APInt::getLowBitsSet(BitWidth, + CountTrailingZeros_32(Align)); + break; + } + case Instruction::GetElementPtr: { + // Analyze all of the subscripts of this getelementptr instruction + // to determine if we can prove known low zero bits. + APInt LocalMask = APInt::getAllOnesValue(BitWidth); + APInt LocalKnownZero(BitWidth, 0), LocalKnownOne(BitWidth, 0); + ComputeMaskedBits(I->getOperand(0), LocalMask, + LocalKnownZero, LocalKnownOne, TD, Depth+1); + unsigned TrailZ = LocalKnownZero.countTrailingOnes(); + + gep_type_iterator GTI = gep_type_begin(I); + for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) { + Value *Index = I->getOperand(i); + if (const StructType *STy = dyn_cast<StructType>(*GTI)) { + // Handle struct member offset arithmetic. + if (!TD) return; + const StructLayout *SL = TD->getStructLayout(STy); + unsigned Idx = cast<ConstantInt>(Index)->getZExtValue(); + uint64_t Offset = SL->getElementOffset(Idx); + TrailZ = std::min(TrailZ, + CountTrailingZeros_64(Offset)); + } else { + // Handle array index arithmetic. + const Type *IndexedTy = GTI.getIndexedType(); + if (!IndexedTy->isSized()) return; + unsigned GEPOpiBits = Index->getType()->getPrimitiveSizeInBits(); + uint64_t TypeSize = TD ? TD->getTypeAllocSize(IndexedTy) : 1; + LocalMask = APInt::getAllOnesValue(GEPOpiBits); + LocalKnownZero = LocalKnownOne = APInt(GEPOpiBits, 0); + ComputeMaskedBits(Index, LocalMask, + LocalKnownZero, LocalKnownOne, TD, Depth+1); + TrailZ = std::min(TrailZ, + unsigned(CountTrailingZeros_64(TypeSize) + + LocalKnownZero.countTrailingOnes())); + } + } + + KnownZero = APInt::getLowBitsSet(BitWidth, TrailZ) & Mask; + break; + } + case Instruction::PHI: { + PHINode *P = cast<PHINode>(I); + // Handle the case of a simple two-predecessor recurrence PHI. + // There's a lot more that could theoretically be done here, but + // this is sufficient to catch some interesting cases. + if (P->getNumIncomingValues() == 2) { + for (unsigned i = 0; i != 2; ++i) { + Value *L = P->getIncomingValue(i); + Value *R = P->getIncomingValue(!i); + User *LU = dyn_cast<User>(L); + if (!LU) + continue; + unsigned Opcode = getOpcode(LU); + // Check for operations that have the property that if + // both their operands have low zero bits, the result + // will have low zero bits. 
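+        // For example (illustrative only; names made up): the recurrence
+        // "%i = phi [ 0, %entry ], [ %i.next, %loop ]" with
+        // "%i.next = add i32 %i, 4" keeps its two low bits known zero.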
+ if (Opcode == Instruction::Add || + Opcode == Instruction::Sub || + Opcode == Instruction::And || + Opcode == Instruction::Or || + Opcode == Instruction::Mul) { + Value *LL = LU->getOperand(0); + Value *LR = LU->getOperand(1); + // Find a recurrence. + if (LL == I) + L = LR; + else if (LR == I) + L = LL; + else + break; + // Ok, we have a PHI of the form L op= R. Check for low + // zero bits. + APInt Mask2 = APInt::getAllOnesValue(BitWidth); + ComputeMaskedBits(R, Mask2, KnownZero2, KnownOne2, TD, Depth+1); + Mask2 = APInt::getLowBitsSet(BitWidth, + KnownZero2.countTrailingOnes()); + + // We need to take the minimum number of known bits + APInt KnownZero3(KnownZero), KnownOne3(KnownOne); + ComputeMaskedBits(L, Mask2, KnownZero3, KnownOne3, TD, Depth+1); + + KnownZero = Mask & + APInt::getLowBitsSet(BitWidth, + std::min(KnownZero2.countTrailingOnes(), + KnownZero3.countTrailingOnes())); + break; + } + } + } + + // Otherwise take the unions of the known bit sets of the operands, + // taking conservative care to avoid excessive recursion. + if (Depth < MaxDepth - 1 && !KnownZero && !KnownOne) { + KnownZero = APInt::getAllOnesValue(BitWidth); + KnownOne = APInt::getAllOnesValue(BitWidth); + for (unsigned i = 0, e = P->getNumIncomingValues(); i != e; ++i) { + // Skip direct self references. + if (P->getIncomingValue(i) == P) continue; + + KnownZero2 = APInt(BitWidth, 0); + KnownOne2 = APInt(BitWidth, 0); + // Recurse, but cap the recursion to one level, because we don't + // want to waste time spinning around in loops. + ComputeMaskedBits(P->getIncomingValue(i), KnownZero | KnownOne, + KnownZero2, KnownOne2, TD, MaxDepth-1); + KnownZero &= KnownZero2; + KnownOne &= KnownOne2; + // If all bits have been ruled out, there's no need to check + // more operands. + if (!KnownZero && !KnownOne) + break; + } + } + break; + } + case Instruction::Call: + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + default: break; + case Intrinsic::ctpop: + case Intrinsic::ctlz: + case Intrinsic::cttz: { + unsigned LowBits = Log2_32(BitWidth)+1; + KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - LowBits); + break; + } + } + } + break; + } +} + +/// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero. We use +/// this predicate to simplify operations downstream. Mask is known to be zero +/// for bits that V cannot have. +bool llvm::MaskedValueIsZero(Value *V, const APInt &Mask, + TargetData *TD, unsigned Depth) { + APInt KnownZero(Mask.getBitWidth(), 0), KnownOne(Mask.getBitWidth(), 0); + ComputeMaskedBits(V, Mask, KnownZero, KnownOne, TD, Depth); + assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); + return (KnownZero & Mask) == Mask; +} + + + +/// ComputeNumSignBits - Return the number of times the sign bit of the +/// register is replicated into the other bits. We know that at least 1 bit +/// is always equal to the sign bit (itself), but other cases can give us +/// information. For example, immediately after an "ashr X, 2", we know that +/// the top 3 bits are all equal to each other, so we return 3. +/// +/// 'Op' must have a scalar integer type. +/// +unsigned llvm::ComputeNumSignBits(Value *V, TargetData *TD, unsigned Depth) { + const IntegerType *Ty = cast<IntegerType>(V->getType()); + unsigned TyBits = Ty->getBitWidth(); + unsigned Tmp, Tmp2; + unsigned FirstAnswer = 1; + + // Note that ConstantInt is handled by the general ComputeMaskedBits case + // below. + + if (Depth == 6) + return 1; // Limit search depth. 
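+  // For example (illustrative only): a value produced by "sext i16 to i32"
+  // is known to have at least 17 identical top bits; see the SExt case
+  // below.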
+ + User *U = dyn_cast<User>(V); + switch (getOpcode(V)) { + default: break; + case Instruction::SExt: + Tmp = TyBits-cast<IntegerType>(U->getOperand(0)->getType())->getBitWidth(); + return ComputeNumSignBits(U->getOperand(0), TD, Depth+1) + Tmp; + + case Instruction::AShr: + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + // ashr X, C -> adds C sign bits. + if (ConstantInt *C = dyn_cast<ConstantInt>(U->getOperand(1))) { + Tmp += C->getZExtValue(); + if (Tmp > TyBits) Tmp = TyBits; + } + return Tmp; + case Instruction::Shl: + if (ConstantInt *C = dyn_cast<ConstantInt>(U->getOperand(1))) { + // shl destroys sign bits. + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + if (C->getZExtValue() >= TyBits || // Bad shift. + C->getZExtValue() >= Tmp) break; // Shifted all sign bits out. + return Tmp - C->getZExtValue(); + } + break; + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: // NOT is handled here. + // Logical binary ops preserve the number of sign bits at the worst. + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + if (Tmp != 1) { + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + FirstAnswer = std::min(Tmp, Tmp2); + // We computed what we know about the sign bits as our first + // answer. Now proceed to the generic code that uses + // ComputeMaskedBits, and pick whichever answer is better. + } + break; + + case Instruction::Select: + Tmp = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + if (Tmp == 1) return 1; // Early out. + Tmp2 = ComputeNumSignBits(U->getOperand(2), TD, Depth+1); + return std::min(Tmp, Tmp2); + + case Instruction::Add: + // Add can have at most one carry bit. Thus we know that the output + // is, at worst, one more bit than the inputs. + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + if (Tmp == 1) return 1; // Early out. + + // Special case decrementing a value (ADD X, -1): + if (ConstantInt *CRHS = dyn_cast<ConstantInt>(U->getOperand(1))) + if (CRHS->isAllOnesValue()) { + APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); + APInt Mask = APInt::getAllOnesValue(TyBits); + ComputeMaskedBits(U->getOperand(0), Mask, KnownZero, KnownOne, TD, + Depth+1); + + // If the input is known to be 0 or 1, the output is 0/-1, which is all + // sign bits set. + if ((KnownZero | APInt(TyBits, 1)) == Mask) + return TyBits; + + // If we are subtracting one from a positive number, there is no carry + // out of the result. + if (KnownZero.isNegative()) + return Tmp; + } + + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + if (Tmp2 == 1) return 1; + return std::min(Tmp, Tmp2)-1; + break; + + case Instruction::Sub: + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + if (Tmp2 == 1) return 1; + + // Handle NEG. + if (ConstantInt *CLHS = dyn_cast<ConstantInt>(U->getOperand(0))) + if (CLHS->isNullValue()) { + APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); + APInt Mask = APInt::getAllOnesValue(TyBits); + ComputeMaskedBits(U->getOperand(1), Mask, KnownZero, KnownOne, + TD, Depth+1); + // If the input is known to be 0 or 1, the output is 0/-1, which is all + // sign bits set. + if ((KnownZero | APInt(TyBits, 1)) == Mask) + return TyBits; + + // If the input is known to be positive (the sign bit is known clear), + // the output of the NEG has the same number of sign bits as the input. + if (KnownZero.isNegative()) + return Tmp2; + + // Otherwise, we treat this like a SUB. + } + + // Sub can have at most one carry bit. Thus we know that the output + // is, at worst, one more bit than the inputs. 
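+    // For example (illustrative only): subtracting two i32 values that each
+    // have at least 4 sign bits yields a result with at least 3 sign bits.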
+ Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + if (Tmp == 1) return 1; // Early out. + return std::min(Tmp, Tmp2)-1; + break; + case Instruction::Trunc: + // FIXME: it's tricky to do anything useful for this, but it is an important + // case for targets like X86. + break; + } + + // Finally, if we can prove that the top bits of the result are 0's or 1's, + // use this information. + APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); + APInt Mask = APInt::getAllOnesValue(TyBits); + ComputeMaskedBits(V, Mask, KnownZero, KnownOne, TD, Depth); + + if (KnownZero.isNegative()) { // sign bit is 0 + Mask = KnownZero; + } else if (KnownOne.isNegative()) { // sign bit is 1; + Mask = KnownOne; + } else { + // Nothing known. + return FirstAnswer; + } + + // Okay, we know that the sign bit in Mask is set. Use CLZ to determine + // the number of identical bits in the top of the input value. + Mask = ~Mask; + Mask <<= Mask.getBitWidth()-TyBits; + // Return # leading zeros. We use 'min' here in case Val was zero before + // shifting. We don't want to return '64' as for an i32 "0". + return std::max(FirstAnswer, std::min(TyBits, Mask.countLeadingZeros())); +} + +/// CannotBeNegativeZero - Return true if we can prove that the specified FP +/// value is never equal to -0.0. +/// +/// NOTE: this function will need to be revisited when we support non-default +/// rounding modes! +/// +bool llvm::CannotBeNegativeZero(const Value *V, unsigned Depth) { + if (const ConstantFP *CFP = dyn_cast<ConstantFP>(V)) + return !CFP->getValueAPF().isNegZero(); + + if (Depth == 6) + return 1; // Limit search depth. + + const Instruction *I = dyn_cast<Instruction>(V); + if (I == 0) return false; + + // (add x, 0.0) is guaranteed to return +0.0, not -0.0. + if (I->getOpcode() == Instruction::Add && + isa<ConstantFP>(I->getOperand(1)) && + cast<ConstantFP>(I->getOperand(1))->isNullValue()) + return true; + + // sitofp and uitofp turn into +0.0 for zero. + if (isa<SIToFPInst>(I) || isa<UIToFPInst>(I)) + return true; + + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) + // sqrt(-0.0) = -0.0, no other negative results are possible. + if (II->getIntrinsicID() == Intrinsic::sqrt) + return CannotBeNegativeZero(II->getOperand(1), Depth+1); + + if (const CallInst *CI = dyn_cast<CallInst>(I)) + if (const Function *F = CI->getCalledFunction()) { + if (F->isDeclaration()) { + switch (F->getNameLen()) { + case 3: // abs(x) != -0.0 + if (!strcmp(F->getNameStart(), "abs")) return true; + break; + case 4: // abs[lf](x) != -0.0 + if (!strcmp(F->getNameStart(), "absf")) return true; + if (!strcmp(F->getNameStart(), "absl")) return true; + break; + } + } + } + + return false; +} + +// This is the recursive version of BuildSubAggregate. It takes a few different +// arguments. Idxs is the index within the nested struct From that we are +// looking at now (which is of type IndexedType). IdxSkip is the number of +// indices from Idxs that should be left out when inserting into the resulting +// struct. To is the result struct built so far, new insertvalue instructions +// build on that. 
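+//
+// For example (illustrative): for From of type { i32, { i32, i32 } } with
+// Idxs = {1} and IdxSkip = 1, this rebuilds the inner { i32, i32 } value by
+// emitting one insertvalue per scalar element that FindInsertedValue can
+// locate inside From.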
+Value *BuildSubAggregate(Value *From, Value* To, const Type *IndexedType,
+                         SmallVector<unsigned, 10> &Idxs,
+                         unsigned IdxSkip,
+                         Instruction *InsertBefore) {
+  const llvm::StructType *STy = llvm::dyn_cast<llvm::StructType>(IndexedType);
+  if (STy) {
+    // Save the original To argument so we can modify it
+    Value *OrigTo = To;
+    // General case, the type indexed by Idxs is a struct
+    for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+      // Process each struct element recursively
+      Idxs.push_back(i);
+      Value *PrevTo = To;
+      To = BuildSubAggregate(From, To, STy->getElementType(i), Idxs, IdxSkip,
+                             InsertBefore);
+      Idxs.pop_back();
+      if (!To) {
+        // Couldn't find any inserted value for this index? Clean up.
+        while (PrevTo != OrigTo) {
+          InsertValueInst* Del = cast<InsertValueInst>(PrevTo);
+          PrevTo = Del->getAggregateOperand();
+          Del->eraseFromParent();
+        }
+        // Stop processing elements
+        break;
+      }
+    }
+    // If we successfully found a value for each of our subaggregates
+    if (To)
+      return To;
+  }
+  // Base case, the type indexed by SourceIdxs is not a struct, or not all of
+  // the struct's elements had a value that was inserted directly. In the
+  // latter case, perhaps we can't determine each of the subelements
+  // individually, but we might be able to find the complete struct somewhere.
+
+  // Find the value that is at that particular spot
+  Value *V = FindInsertedValue(From, Idxs.begin(), Idxs.end());
+
+  if (!V)
+    return NULL;
+
+  // Insert the value in the new (sub) aggregate
+  return llvm::InsertValueInst::Create(To, V, Idxs.begin() + IdxSkip,
+                                       Idxs.end(), "tmp", InsertBefore);
+}
+
+// This helper takes a nested struct and extracts a part of it (which is again
+// a struct) into a new value. For example, given the struct:
+//   { a, { b, { c, d }, e } }
+// and the indices "1, 1" this returns
+//   { c, d }.
+//
+// It does this by inserting an insertvalue for each element in the resulting
+// struct, as opposed to just inserting a single struct. This will only work if
+// each of the elements of the substruct are known (i.e., inserted into From by
+// an insertvalue instruction somewhere).
+//
+// All inserted insertvalue instructions are inserted before InsertBefore
+Value *BuildSubAggregate(Value *From, const unsigned *idx_begin,
+                         const unsigned *idx_end, Instruction *InsertBefore) {
+  assert(InsertBefore && "Must have someplace to insert!");
+  const Type *IndexedType = ExtractValueInst::getIndexedType(From->getType(),
+                                                             idx_begin,
+                                                             idx_end);
+  Value *To = UndefValue::get(IndexedType);
+  SmallVector<unsigned, 10> Idxs(idx_begin, idx_end);
+  unsigned IdxSkip = Idxs.size();
+
+  return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore);
+}
+
+/// FindInsertedValue - Given an aggregate and a sequence of indices, see if
+/// the scalar value indexed is already around as a register, for example if
+/// it were inserted directly into the aggregate.
+///
+/// If InsertBefore is not null, this function will duplicate (modified)
+/// insertvalues when a part of a nested struct is extracted.
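+///
+/// For example (illustrative): if %agg was built up as
+///   %a = insertvalue {i32, i32} undef, i32 1, 0
+///   %b = insertvalue {i32, i32} %a, i32 2, 1
+/// then calling this function on %b with index sequence {1} returns the value
+/// that was inserted at index 1 (i32 2) without creating any new instructions.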
+Value *llvm::FindInsertedValue(Value *V, const unsigned *idx_begin,
+                               const unsigned *idx_end, Instruction *InsertBefore) {
+  // Nothing to index? Just return V then (this is useful at the end of our
+  // recursion).
+  if (idx_begin == idx_end)
+    return V;
+  // We have indices, so V should have an indexable type
+  assert((isa<StructType>(V->getType()) || isa<ArrayType>(V->getType()))
+         && "Not looking at a struct or array?");
+  assert(ExtractValueInst::getIndexedType(V->getType(), idx_begin, idx_end)
+         && "Invalid indices for type?");
+  const CompositeType *PTy = cast<CompositeType>(V->getType());
+
+  if (isa<UndefValue>(V))
+    return UndefValue::get(ExtractValueInst::getIndexedType(PTy,
+                                                            idx_begin,
+                                                            idx_end));
+  else if (isa<ConstantAggregateZero>(V))
+    return Constant::getNullValue(ExtractValueInst::getIndexedType(PTy,
+                                                                   idx_begin,
+                                                                   idx_end));
+  else if (Constant *C = dyn_cast<Constant>(V)) {
+    if (isa<ConstantArray>(C) || isa<ConstantStruct>(C))
+      // Recursively process this constant
+      return FindInsertedValue(C->getOperand(*idx_begin), idx_begin + 1,
+                               idx_end, InsertBefore);
+  } else if (InsertValueInst *I = dyn_cast<InsertValueInst>(V)) {
+    // Loop the indices for the insertvalue instruction in parallel with the
+    // requested indices
+    const unsigned *req_idx = idx_begin;
+    for (const unsigned *i = I->idx_begin(), *e = I->idx_end();
+         i != e; ++i, ++req_idx) {
+      if (req_idx == idx_end) {
+        if (InsertBefore)
+          // The requested index identifies a part of a nested aggregate.
+          // Handle this specially. For example,
+          // %A = insertvalue { i32, {i32, i32 } } undef, i32 10, 1, 0
+          // %B = insertvalue { i32, {i32, i32 } } %A, i32 11, 1, 1
+          // %C = extractvalue {i32, { i32, i32 } } %B, 1
+          // This can be changed into
+          // %A = insertvalue {i32, i32 } undef, i32 10, 0
+          // %C = insertvalue {i32, i32 } %A, i32 11, 1
+          // which allows the unused 0,0 element from the nested struct to be
+          // removed.
+          return BuildSubAggregate(V, idx_begin, req_idx, InsertBefore);
+        else
+          // We can't handle this without inserting insertvalues
+          return 0;
+      }
+
+      // This insertvalue inserts something other than what we are looking
+      // for. See if the (aggregate) value inserted into has the value we are
+      // looking for, then.
+      if (*req_idx != *i)
+        return FindInsertedValue(I->getAggregateOperand(), idx_begin, idx_end,
+                                 InsertBefore);
+    }
+    // If we end up here, the indices of the insertvalue match with those
+    // requested (though possibly only partially). Now we recursively look at
+    // the inserted value, passing any remaining indices.
+    return FindInsertedValue(I->getInsertedValueOperand(), req_idx, idx_end,
+                             InsertBefore);
+  } else if (ExtractValueInst *I = dyn_cast<ExtractValueInst>(V)) {
+    // If we're extracting a value from an aggregate that was itself extracted
+    // from something else, we can extract from that something else directly
+    // instead. However, we will need to chain I's indices with the requested
+    // indices.
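+    // For example (illustrative): if I is
+    //   %I = extractvalue { { i32, i32 }, i32 } %agg, 0
+    // and the requested index sequence is {1}, we instead look up indices
+    // {0, 1} directly in %agg.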
+ + // Calculate the number of indices required + unsigned size = I->getNumIndices() + (idx_end - idx_begin); + // Allocate some space to put the new indices in + SmallVector<unsigned, 5> Idxs; + Idxs.reserve(size); + // Add indices from the extract value instruction + for (const unsigned *i = I->idx_begin(), *e = I->idx_end(); + i != e; ++i) + Idxs.push_back(*i); + + // Add requested indices + for (const unsigned *i = idx_begin, *e = idx_end; i != e; ++i) + Idxs.push_back(*i); + + assert(Idxs.size() == size + && "Number of indices added not correct?"); + + return FindInsertedValue(I->getAggregateOperand(), Idxs.begin(), Idxs.end(), + InsertBefore); + } + // Otherwise, we don't know (such as, extracting from a function return value + // or load instruction) + return 0; +} + +/// GetConstantStringInfo - This function computes the length of a +/// null-terminated C string pointed to by V. If successful, it returns true +/// and returns the string in Str. If unsuccessful, it returns false. +bool llvm::GetConstantStringInfo(Value *V, std::string &Str, uint64_t Offset, + bool StopAtNul) { + // If V is NULL then return false; + if (V == NULL) return false; + + // Look through bitcast instructions. + if (BitCastInst *BCI = dyn_cast<BitCastInst>(V)) + return GetConstantStringInfo(BCI->getOperand(0), Str, Offset, StopAtNul); + + // If the value is not a GEP instruction nor a constant expression with a + // GEP instruction, then return false because ConstantArray can't occur + // any other way + User *GEP = 0; + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(V)) { + GEP = GEPI; + } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { + if (CE->getOpcode() == Instruction::BitCast) + return GetConstantStringInfo(CE->getOperand(0), Str, Offset, StopAtNul); + if (CE->getOpcode() != Instruction::GetElementPtr) + return false; + GEP = CE; + } + + if (GEP) { + // Make sure the GEP has exactly three arguments. + if (GEP->getNumOperands() != 3) + return false; + + // Make sure the index-ee is a pointer to array of i8. + const PointerType *PT = cast<PointerType>(GEP->getOperand(0)->getType()); + const ArrayType *AT = dyn_cast<ArrayType>(PT->getElementType()); + if (AT == 0 || AT->getElementType() != Type::Int8Ty) + return false; + + // Check to make sure that the first operand of the GEP is an integer and + // has value 0 so that we are sure we're indexing into the initializer. + ConstantInt *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1)); + if (FirstIdx == 0 || !FirstIdx->isZero()) + return false; + + // If the second index isn't a ConstantInt, then this is a variable index + // into the array. If this occurs, we can't say anything meaningful about + // the string. + uint64_t StartIdx = 0; + if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(2))) + StartIdx = CI->getZExtValue(); + else + return false; + return GetConstantStringInfo(GEP->getOperand(0), Str, StartIdx+Offset, + StopAtNul); + } + + // The GEP instruction, constant or instruction, must reference a global + // variable that is a constant and is initialized. The referenced constant + // initializer is the array that we'll use for optimization. + GlobalVariable* GV = dyn_cast<GlobalVariable>(V); + if (!GV || !GV->isConstant() || !GV->hasInitializer()) + return false; + Constant *GlobalInit = GV->getInitializer(); + + // Handle the ConstantAggregateZero case + if (isa<ConstantAggregateZero>(GlobalInit)) { + // This is a degenerate case. The initializer is constant zero so the + // length of the string must be zero. 
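+    // For example (illustrative), a GEP into @buf = global [16 x i8]
+    // zeroinitializer yields the empty string at any in-bounds offset.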
+    Str.clear();
+    return true;
+  }
+
+  // Must be a ConstantArray
+  ConstantArray *Array = dyn_cast<ConstantArray>(GlobalInit);
+  if (Array == 0 || Array->getType()->getElementType() != Type::Int8Ty)
+    return false;
+
+  // Get the number of elements in the array
+  uint64_t NumElts = Array->getType()->getNumElements();
+
+  if (Offset > NumElts)
+    return false;
+
+  // Traverse the constant array from 'Offset', which is where the GEP refers
+  // to in the array.
+  Str.reserve(NumElts-Offset);
+  for (unsigned i = Offset; i != NumElts; ++i) {
+    Constant *Elt = Array->getOperand(i);
+    ConstantInt *CI = dyn_cast<ConstantInt>(Elt);
+    if (!CI) // This array isn't suitable, non-int initializer.
+      return false;
+    if (StopAtNul && CI->isZero())
+      return true; // We found the end of the string; success!
+    Str += (char)CI->getZExtValue();
+  }
+
+  // The array isn't null-terminated, but maybe this is a memcpy, not a strcpy.
+  return true;
+}
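+
+// Usage sketch (illustrative; the caller and CI are hypothetical, not part of
+// this patch): a libcall simplifier could use GetConstantStringInfo to fold
+// strlen of a constant global string, roughly:
+//   std::string Str;
+//   if (GetConstantStringInfo(CI->getOperand(1), Str, 0, true))
+//     return ConstantInt::get(CI->getType(), Str.size());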