diff options
Diffstat (limited to 'contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp')
-rw-r--r-- | contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp | 456 |
1 files changed, 381 insertions, 75 deletions
diff --git a/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index f54d8ad..04f9a64 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -51,6 +51,7 @@ #include "llvm/Transforms/PGOInstrumentation.h" #include "CFGMST.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/BlockFrequencyInfo.h" @@ -59,6 +60,7 @@ #include "llvm/Analysis/IndirectCallSiteVisitor.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" @@ -75,6 +77,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <algorithm> #include <string> +#include <unordered_map> #include <utility> #include <vector> @@ -83,6 +86,7 @@ using namespace llvm; #define DEBUG_TYPE "pgo-instrumentation" STATISTIC(NumOfPGOInstrument, "Number of edges instrumented."); +STATISTIC(NumOfPGOSelectInsts, "Number of select instruction instrumented."); STATISTIC(NumOfPGOEdge, "Number of edges."); STATISTIC(NumOfPGOBB, "Number of basic-blocks."); STATISTIC(NumOfPGOSplit, "Number of critical edge splits."); @@ -112,17 +116,89 @@ static cl::opt<unsigned> MaxNumAnnotations( cl::desc("Max number of annotations for a single indirect " "call callsite")); +// Command line option to control appending FunctionHash to the name of a COMDAT +// function. This is to avoid the hash mismatch caused by the preinliner. +static cl::opt<bool> DoComdatRenaming( + "do-comdat-renaming", cl::init(false), cl::Hidden, + cl::desc("Append function hash to the name of COMDAT function to avoid " + "function hash mismatch due to the preinliner")); + // Command line option to enable/disable the warning about missing profile // information. -static cl::opt<bool> NoPGOWarnMissing("no-pgo-warn-missing", cl::init(false), - cl::Hidden); +static cl::opt<bool> PGOWarnMissing("pgo-warn-missing-function", + cl::init(false), + cl::Hidden); // Command line option to enable/disable the warning about a hash mismatch in // the profile data. static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), cl::Hidden); +// Command line option to enable/disable the warning about a hash mismatch in +// the profile data for Comdat functions, which often turns out to be false +// positive due to the pre-instrumentation inline. +static cl::opt<bool> NoPGOWarnMismatchComdat("no-pgo-warn-mismatch-comdat", + cl::init(true), cl::Hidden); + +// Command line option to enable/disable select instruction instrumentation. +static cl::opt<bool> PGOInstrSelect("pgo-instr-select", cl::init(true), + cl::Hidden); namespace { + +/// The select instruction visitor plays three roles specified +/// by the mode. In \c VM_counting mode, it simply counts the number of +/// select instructions. In \c VM_instrument mode, it inserts code to count +/// the number times TrueValue of select is taken. In \c VM_annotate mode, +/// it reads the profile data and annotate the select instruction with metadata. +enum VisitMode { VM_counting, VM_instrument, VM_annotate }; +class PGOUseFunc; + +/// Instruction Visitor class to visit select instructions. +struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { + Function &F; + unsigned NSIs = 0; // Number of select instructions instrumented. + VisitMode Mode = VM_counting; // Visiting mode. + unsigned *CurCtrIdx = nullptr; // Pointer to current counter index. + unsigned TotalNumCtrs = 0; // Total number of counters + GlobalVariable *FuncNameVar = nullptr; + uint64_t FuncHash = 0; + PGOUseFunc *UseFunc = nullptr; + + SelectInstVisitor(Function &Func) : F(Func) {} + + void countSelects(Function &Func) { + Mode = VM_counting; + visit(Func); + } + // Visit the IR stream and instrument all select instructions. \p + // Ind is a pointer to the counter index variable; \p TotalNC + // is the total number of counters; \p FNV is the pointer to the + // PGO function name var; \p FHash is the function hash. + void instrumentSelects(Function &Func, unsigned *Ind, unsigned TotalNC, + GlobalVariable *FNV, uint64_t FHash) { + Mode = VM_instrument; + CurCtrIdx = Ind; + TotalNumCtrs = TotalNC; + FuncHash = FHash; + FuncNameVar = FNV; + visit(Func); + } + + // Visit the IR stream and annotate all select instructions. + void annotateSelects(Function &Func, PGOUseFunc *UF, unsigned *Ind) { + Mode = VM_annotate; + UseFunc = UF; + CurCtrIdx = Ind; + visit(Func); + } + + void instrumentOneSelectInst(SelectInst &SI); + void annotateOneSelectInst(SelectInst &SI); + // Visit \p SI instruction and perform tasks according to visit mode. + void visitSelectInst(SelectInst &SI); + unsigned getNumOfSelectInsts() const { return NSIs; } +}; + class PGOInstrumentationGenLegacyPass : public ModulePass { public: static char ID; @@ -132,9 +208,7 @@ public: *PassRegistry::getPassRegistry()); } - const char *getPassName() const override { - return "PGOInstrumentationGenPass"; - } + StringRef getPassName() const override { return "PGOInstrumentationGenPass"; } private: bool runOnModule(Module &M) override; @@ -157,9 +231,7 @@ public: *PassRegistry::getPassRegistry()); } - const char *getPassName() const override { - return "PGOInstrumentationUsePass"; - } + StringRef getPassName() const override { return "PGOInstrumentationUsePass"; } private: std::string ProfileFileName; @@ -169,6 +241,7 @@ private: AU.addRequired<BlockFrequencyInfoWrapperPass>(); } }; + } // end anonymous namespace char PGOInstrumentationGenLegacyPass::ID = 0; @@ -238,8 +311,13 @@ template <class Edge, class BBInfo> class FuncPGOInstrumentation { private: Function &F; void computeCFGHash(); + void renameComdatFunction(); + // A map that stores the Comdat group in function F. + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers; public: + std::vector<Instruction *> IndirectCallSites; + SelectInstVisitor SIVisitor; std::string FuncName; GlobalVariable *FuncNameVar; // CFG hash value for this function. @@ -255,18 +333,32 @@ public: // Return the auxiliary BB information. BBInfo &getBBInfo(const BasicBlock *BB) const { return MST.getBBInfo(BB); } + // Return the auxiliary BB information if available. + BBInfo *findBBInfo(const BasicBlock *BB) const { return MST.findBBInfo(BB); } + // Dump edges and BB information. void dumpInfo(std::string Str = "") const { MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName + " Hash: " + Twine(FunctionHash) + "\t" + Str); } - FuncPGOInstrumentation(Function &Func, bool CreateGlobalVar = false, - BranchProbabilityInfo *BPI = nullptr, - BlockFrequencyInfo *BFI = nullptr) - : F(Func), FunctionHash(0), MST(F, BPI, BFI) { + FuncPGOInstrumentation( + Function &Func, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, + bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr, + BlockFrequencyInfo *BFI = nullptr) + : F(Func), ComdatMembers(ComdatMembers), SIVisitor(Func), FunctionHash(0), + MST(F, BPI, BFI) { + + // This should be done before CFG hash computation. + SIVisitor.countSelects(Func); + NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); + IndirectCallSites = findIndirectCallSites(Func); + FuncName = getPGOFuncName(F); computeCFGHash(); + if (ComdatMembers.size()) + renameComdatFunction(); DEBUG(dumpInfo("after CFGMST")); NumOfPGOBB += MST.BBInfos.size(); @@ -281,6 +373,16 @@ public: if (CreateGlobalVar) FuncNameVar = createPGOFuncNameVar(F, FuncName); } + + // Return the number of profile counters needed for the function. + unsigned getNumCounters() { + unsigned NumCounters = 0; + for (auto &E : this->MST.AllEdges) { + if (!E->InMST && !E->Removed) + NumCounters++; + } + return NumCounters + SIVisitor.getNumOfSelectInsts(); + } }; // Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index @@ -293,13 +395,90 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { const TerminatorInst *TI = BB.getTerminator(); for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) { BasicBlock *Succ = TI->getSuccessor(I); - uint32_t Index = getBBInfo(Succ).Index; + auto BI = findBBInfo(Succ); + if (BI == nullptr) + continue; + uint32_t Index = BI->Index; for (int J = 0; J < 4; J++) Indexes.push_back((char)(Index >> (J * 8))); } } JC.update(Indexes); - FunctionHash = (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); + FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 | + (uint64_t)IndirectCallSites.size() << 48 | + (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); +} + +// Check if we can safely rename this Comdat function. +static bool canRenameComdat( + Function &F, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers) { + if (!DoComdatRenaming || !canRenameComdatFunc(F, true)) + return false; + + // FIXME: Current only handle those Comdat groups that only containing one + // function and function aliases. + // (1) For a Comdat group containing multiple functions, we need to have a + // unique postfix based on the hashes for each function. There is a + // non-trivial code refactoring to do this efficiently. + // (2) Variables can not be renamed, so we can not rename Comdat function in a + // group including global vars. + Comdat *C = F.getComdat(); + for (auto &&CM : make_range(ComdatMembers.equal_range(C))) { + if (dyn_cast<GlobalAlias>(CM.second)) + continue; + Function *FM = dyn_cast<Function>(CM.second); + if (FM != &F) + return false; + } + return true; +} + +// Append the CFGHash to the Comdat function name. +template <class Edge, class BBInfo> +void FuncPGOInstrumentation<Edge, BBInfo>::renameComdatFunction() { + if (!canRenameComdat(F, ComdatMembers)) + return; + std::string OrigName = F.getName().str(); + std::string NewFuncName = + Twine(F.getName() + "." + Twine(FunctionHash)).str(); + F.setName(Twine(NewFuncName)); + GlobalAlias::create(GlobalValue::WeakAnyLinkage, OrigName, &F); + FuncName = Twine(FuncName + "." + Twine(FunctionHash)).str(); + Comdat *NewComdat; + Module *M = F.getParent(); + // For AvailableExternallyLinkage functions, change the linkage to + // LinkOnceODR and put them into comdat. This is because after renaming, there + // is no backup external copy available for the function. + if (!F.hasComdat()) { + assert(F.getLinkage() == GlobalValue::AvailableExternallyLinkage); + NewComdat = M->getOrInsertComdat(StringRef(NewFuncName)); + F.setLinkage(GlobalValue::LinkOnceODRLinkage); + F.setComdat(NewComdat); + return; + } + + // This function belongs to a single function Comdat group. + Comdat *OrigComdat = F.getComdat(); + std::string NewComdatName = + Twine(OrigComdat->getName() + "." + Twine(FunctionHash)).str(); + NewComdat = M->getOrInsertComdat(StringRef(NewComdatName)); + NewComdat->setSelectionKind(OrigComdat->getSelectionKind()); + + for (auto &&CM : make_range(ComdatMembers.equal_range(OrigComdat))) { + if (GlobalAlias *GA = dyn_cast<GlobalAlias>(CM.second)) { + // For aliases, change the name directly. + assert(dyn_cast<Function>(GA->getAliasee()->stripPointerCasts()) == &F); + std::string OrigGAName = GA->getName().str(); + GA->setName(Twine(GA->getName() + "." + Twine(FunctionHash))); + GlobalAlias::create(GlobalValue::WeakAnyLinkage, OrigGAName, GA); + continue; + } + // Must be a function. + Function *CF = dyn_cast<Function>(CM.second); + assert(CF); + CF->setComdat(NewComdat); + } } // Given a CFG E to be instrumented, find which BB to place the instrumented @@ -340,15 +519,12 @@ BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) { // Visit all edge and instrument the edges not in MST, and do value profiling. // Critical edges will be split. -static void instrumentOneFunc(Function &F, Module *M, - BranchProbabilityInfo *BPI, - BlockFrequencyInfo *BFI) { - unsigned NumCounters = 0; - FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, true, BPI, BFI); - for (auto &E : FuncInfo.MST.AllEdges) { - if (!E->InMST && !E->Removed) - NumCounters++; - } +static void instrumentOneFunc( + Function &F, Module *M, BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFI, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers) { + FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, ComdatMembers, true, BPI, + BFI); + unsigned NumCounters = FuncInfo.getNumCounters(); uint32_t I = 0; Type *I8PtrTy = Type::getInt8PtrTy(M->getContext()); @@ -367,11 +543,16 @@ static void instrumentOneFunc(Function &F, Module *M, Builder.getInt32(I++)}); } + // Now instrument select instructions: + FuncInfo.SIVisitor.instrumentSelects(F, &I, NumCounters, FuncInfo.FuncNameVar, + FuncInfo.FunctionHash); + assert(I == NumCounters); + if (DisableValueProfiling) return; unsigned NumIndirectCallSites = 0; - for (auto &I : findIndirectCallSites(F)) { + for (auto &I : FuncInfo.IndirectCallSites) { CallSite CS(I); Value *Callee = CS.getCalledValue(); DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = " @@ -456,10 +637,12 @@ static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) { class PGOUseFunc { public: - PGOUseFunc(Function &Func, Module *Modu, BranchProbabilityInfo *BPI = nullptr, + PGOUseFunc(Function &Func, Module *Modu, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, + BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr) - : F(Func), M(Modu), FuncInfo(Func, false, BPI, BFI), - FreqAttr(FFA_Normal) {} + : F(Func), M(Modu), FuncInfo(Func, ComdatMembers, false, BPI, BFI), + CountPosition(0), ProfileCountSize(0), FreqAttr(FFA_Normal) {} // Read counts for the instrumented BB from profile. bool readCounters(IndexedInstrProfReader *PGOReader); @@ -479,24 +662,37 @@ public: // Return the function hotness from the profile. FuncFreqAttr getFuncFreqAttr() const { return FreqAttr; } + // Return the function hash. + uint64_t getFuncHash() const { return FuncInfo.FunctionHash; } // Return the profile record for this function; InstrProfRecord &getProfileRecord() { return ProfileRecord; } + // Return the auxiliary BB information. + UseBBInfo &getBBInfo(const BasicBlock *BB) const { + return FuncInfo.getBBInfo(BB); + } + + // Return the auxiliary BB information if available. + UseBBInfo *findBBInfo(const BasicBlock *BB) const { + return FuncInfo.findBBInfo(BB); + } + private: Function &F; Module *M; // This member stores the shared information with class PGOGenFunc. FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo; - // Return the auxiliary BB information. - UseBBInfo &getBBInfo(const BasicBlock *BB) const { - return FuncInfo.getBBInfo(BB); - } - // The maximum count value in the profile. This is only used in PGO use // compilation. uint64_t ProgramMaxCount; + // Position of counter that remains to be read. + uint32_t CountPosition; + + // Total size of the profile count for this function. + uint32_t ProfileCountSize; + // ProfileRecord for this function. InstrProfRecord ProfileRecord; @@ -535,6 +731,7 @@ private: void PGOUseFunc::setInstrumentedCounts( const std::vector<uint64_t> &CountFromProfile) { + assert(FuncInfo.getNumCounters() == CountFromProfile.size()); // Use a worklist as we will update the vector during the iteration. std::vector<PGOUseEdge *> WorkList; for (auto &E : FuncInfo.MST.AllEdges) @@ -564,6 +761,8 @@ void PGOUseFunc::setInstrumentedCounts( NewEdge1.InMST = true; getBBInfo(InstrBB).setBBInfoCount(CountValue); } + ProfileCountSize = CountFromProfile.size(); + CountPosition = I; } // Set the count value for the unknown edge. There should be one and only one @@ -594,11 +793,15 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) { bool SkipWarning = false; if (Err == instrprof_error::unknown_function) { NumOfPGOMissing++; - SkipWarning = NoPGOWarnMissing; + SkipWarning = !PGOWarnMissing; } else if (Err == instrprof_error::hash_mismatch || Err == instrprof_error::malformed) { NumOfPGOMismatch++; - SkipWarning = NoPGOWarnMismatch; + SkipWarning = + NoPGOWarnMismatch || + (NoPGOWarnMismatchComdat && + (F.hasComdat() || + F.getLinkage() == GlobalValue::AvailableExternallyLinkage)); } if (SkipWarning) @@ -663,27 +866,38 @@ void PGOUseFunc::populateCounters() { // For efficient traversal, it's better to start from the end as most // of the instrumented edges are at the end. for (auto &BB : reverse(F)) { - UseBBInfo &Count = getBBInfo(&BB); - if (!Count.CountValid) { - if (Count.UnknownCountOutEdge == 0) { - Count.CountValue = sumEdgeCount(Count.OutEdges); - Count.CountValid = true; + UseBBInfo *Count = findBBInfo(&BB); + if (Count == nullptr) + continue; + if (!Count->CountValid) { + if (Count->UnknownCountOutEdge == 0) { + Count->CountValue = sumEdgeCount(Count->OutEdges); + Count->CountValid = true; Changes = true; - } else if (Count.UnknownCountInEdge == 0) { - Count.CountValue = sumEdgeCount(Count.InEdges); - Count.CountValid = true; + } else if (Count->UnknownCountInEdge == 0) { + Count->CountValue = sumEdgeCount(Count->InEdges); + Count->CountValid = true; Changes = true; } } - if (Count.CountValid) { - if (Count.UnknownCountOutEdge == 1) { - uint64_t Total = Count.CountValue - sumEdgeCount(Count.OutEdges); - setEdgeCount(Count.OutEdges, Total); + if (Count->CountValid) { + if (Count->UnknownCountOutEdge == 1) { + uint64_t Total = 0; + uint64_t OutSum = sumEdgeCount(Count->OutEdges); + // If the one of the successor block can early terminate (no-return), + // we can end up with situation where out edge sum count is larger as + // the source BB's count is collected by a post-dominated block. + if (Count->CountValue > OutSum) + Total = Count->CountValue - OutSum; + setEdgeCount(Count->OutEdges, Total); Changes = true; } - if (Count.UnknownCountInEdge == 1) { - uint64_t Total = Count.CountValue - sumEdgeCount(Count.InEdges); - setEdgeCount(Count.InEdges, Total); + if (Count->UnknownCountInEdge == 1) { + uint64_t Total = 0; + uint64_t InSum = sumEdgeCount(Count->InEdges); + if (Count->CountValue > InSum) + Total = Count->CountValue - InSum; + setEdgeCount(Count->InEdges, Total); Changes = true; } } @@ -693,24 +907,50 @@ void PGOUseFunc::populateCounters() { DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n"); #ifndef NDEBUG // Assert every BB has a valid counter. - for (auto &BB : F) - assert(getBBInfo(&BB).CountValid && "BB count is not valid"); + for (auto &BB : F) { + auto BI = findBBInfo(&BB); + if (BI == nullptr) + continue; + assert(BI->CountValid && "BB count is not valid"); + } #endif uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue; F.setEntryCount(FuncEntryCount); uint64_t FuncMaxCount = FuncEntryCount; - for (auto &BB : F) - FuncMaxCount = std::max(FuncMaxCount, getBBInfo(&BB).CountValue); + for (auto &BB : F) { + auto BI = findBBInfo(&BB); + if (BI == nullptr) + continue; + FuncMaxCount = std::max(FuncMaxCount, BI->CountValue); + } markFunctionAttributes(FuncEntryCount, FuncMaxCount); + // Now annotate select instructions + FuncInfo.SIVisitor.annotateSelects(F, this, &CountPosition); + assert(CountPosition == ProfileCountSize); + DEBUG(FuncInfo.dumpInfo("after reading profile.")); } +static void setProfMetadata(Module *M, Instruction *TI, + ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) { + MDBuilder MDB(M->getContext()); + assert(MaxCount > 0 && "Bad max count"); + uint64_t Scale = calculateCountScale(MaxCount); + SmallVector<unsigned, 4> Weights; + for (const auto &ECI : EdgeCounts) + Weights.push_back(scaleBranchCount(ECI, Scale)); + + DEBUG(dbgs() << "Weight is: "; + for (const auto &W : Weights) { dbgs() << W << " "; } + dbgs() << "\n";); + TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); +} + // Assign the scaled count values to the BB with multiple out edges. void PGOUseFunc::setBranchWeights() { // Generate MD_prof metadata for every branch instruction. DEBUG(dbgs() << "\nSetting branch weights.\n"); - MDBuilder MDB(M->getContext()); for (auto &BB : F) { TerminatorInst *TI = BB.getTerminator(); if (TI->getNumSuccessors() < 2) @@ -723,7 +963,7 @@ void PGOUseFunc::setBranchWeights() { // We have a non-zero Branch BB. const UseBBInfo &BBCountInfo = getBBInfo(&BB); unsigned Size = BBCountInfo.OutEdges.size(); - SmallVector<unsigned, 2> EdgeCounts(Size, 0); + SmallVector<uint64_t, 2> EdgeCounts(Size, 0); uint64_t MaxCount = 0; for (unsigned s = 0; s < Size; s++) { const PGOUseEdge *E = BBCountInfo.OutEdges[s]; @@ -737,20 +977,64 @@ void PGOUseFunc::setBranchWeights() { MaxCount = EdgeCount; EdgeCounts[SuccNum] = EdgeCount; } - assert(MaxCount > 0 && "Bad max count"); - uint64_t Scale = calculateCountScale(MaxCount); - SmallVector<unsigned, 4> Weights; - for (const auto &ECI : EdgeCounts) - Weights.push_back(scaleBranchCount(ECI, Scale)); - - TI->setMetadata(llvm::LLVMContext::MD_prof, - MDB.createBranchWeights(Weights)); - DEBUG(dbgs() << "Weight is: "; - for (const auto &W : Weights) { dbgs() << W << " "; } - dbgs() << "\n";); + setProfMetadata(M, TI, EdgeCounts, MaxCount); } } +void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) { + Module *M = F.getParent(); + IRBuilder<> Builder(&SI); + Type *Int64Ty = Builder.getInt64Ty(); + Type *I8PtrTy = Builder.getInt8PtrTy(); + auto *Step = Builder.CreateZExt(SI.getCondition(), Int64Ty); + Builder.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step), + {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), + Builder.getInt64(FuncHash), + Builder.getInt32(TotalNumCtrs), Builder.getInt32(*CurCtrIdx), Step}); + ++(*CurCtrIdx); +} + +void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) { + std::vector<uint64_t> &CountFromProfile = UseFunc->getProfileRecord().Counts; + assert(*CurCtrIdx < CountFromProfile.size() && + "Out of bound access of counters"); + uint64_t SCounts[2]; + SCounts[0] = CountFromProfile[*CurCtrIdx]; // True count + ++(*CurCtrIdx); + uint64_t TotalCount = 0; + auto BI = UseFunc->findBBInfo(SI.getParent()); + if (BI != nullptr) + TotalCount = BI->CountValue; + // False Count + SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0); + uint64_t MaxCount = std::max(SCounts[0], SCounts[1]); + if (MaxCount) + setProfMetadata(F.getParent(), &SI, SCounts, MaxCount); +} + +void SelectInstVisitor::visitSelectInst(SelectInst &SI) { + if (!PGOInstrSelect) + return; + // FIXME: do not handle this yet. + if (SI.getCondition()->getType()->isVectorTy()) + return; + + NSIs++; + switch (Mode) { + case VM_counting: + return; + case VM_instrument: + instrumentOneSelectInst(SI); + return; + case VM_annotate: + annotateOneSelectInst(SI); + return; + } + + llvm_unreachable("Unknown visiting mode"); +} + // Traverse all the indirect callsites and annotate the instructions. void PGOUseFunc::annotateIndirectCallSites() { if (DisableValueProfiling) @@ -760,7 +1044,7 @@ void PGOUseFunc::annotateIndirectCallSites() { createPGOFuncNameMetadata(F, FuncInfo.FuncName); unsigned IndirectCallSiteIndex = 0; - auto IndirectCallSites = findIndirectCallSites(F); + auto &IndirectCallSites = FuncInfo.IndirectCallSites; unsigned NumValueSites = ProfileRecord.getNumValueSites(IPVK_IndirectCallTarget); if (NumValueSites != IndirectCallSites.size()) { @@ -784,7 +1068,7 @@ void PGOUseFunc::annotateIndirectCallSites() { } } // end anonymous namespace -// Create a COMDAT variable IR_LEVEL_PROF_VARNAME to make the runtime +// Create a COMDAT variable INSTR_PROF_RAW_VERSION_VAR to make the runtime // aware this is an ir_level profile so it can set the version flag. static void createIRLevelProfileFlagVariable(Module &M) { Type *IntTy64 = Type::getInt64Ty(M.getContext()); @@ -792,26 +1076,47 @@ static void createIRLevelProfileFlagVariable(Module &M) { auto IRLevelVersionVariable = new GlobalVariable( M, IntTy64, true, GlobalVariable::ExternalLinkage, Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)), - INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR)); + INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR)); IRLevelVersionVariable->setVisibility(GlobalValue::DefaultVisibility); Triple TT(M.getTargetTriple()); if (!TT.supportsCOMDAT()) IRLevelVersionVariable->setLinkage(GlobalValue::WeakAnyLinkage); else IRLevelVersionVariable->setComdat(M.getOrInsertComdat( - StringRef(INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR)))); + StringRef(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR)))); +} + +// Collect the set of members for each Comdat in module M and store +// in ComdatMembers. +static void collectComdatMembers( + Module &M, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers) { + if (!DoComdatRenaming) + return; + for (Function &F : M) + if (Comdat *C = F.getComdat()) + ComdatMembers.insert(std::make_pair(C, &F)); + for (GlobalVariable &GV : M.globals()) + if (Comdat *C = GV.getComdat()) + ComdatMembers.insert(std::make_pair(C, &GV)); + for (GlobalAlias &GA : M.aliases()) + if (Comdat *C = GA.getComdat()) + ComdatMembers.insert(std::make_pair(C, &GA)); } static bool InstrumentAllFunctions( Module &M, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) { createIRLevelProfileFlagVariable(M); + std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers; + collectComdatMembers(M, ComdatMembers); + for (auto &F : M) { if (F.isDeclaration()) continue; auto *BPI = LookupBPI(F); auto *BFI = LookupBFI(F); - instrumentOneFunc(F, &M, BPI, BFI); + instrumentOneFunc(F, &M, BPI, BFI, ComdatMembers); } return true; } @@ -830,7 +1135,7 @@ bool PGOInstrumentationGenLegacyPass::runOnModule(Module &M) { } PreservedAnalyses PGOInstrumentationGen::run(Module &M, - AnalysisManager<Module> &AM) { + ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto LookupBPI = [&FAM](Function &F) { @@ -877,6 +1182,8 @@ static bool annotateAllFunctions( return false; } + std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers; + collectComdatMembers(M, ComdatMembers); std::vector<Function *> HotFunctions; std::vector<Function *> ColdFunctions; for (auto &F : M) { @@ -884,7 +1191,7 @@ static bool annotateAllFunctions( continue; auto *BPI = LookupBPI(F); auto *BFI = LookupBFI(F); - PGOUseFunc Func(F, &M, BPI, BFI); + PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI); if (!Func.readCounters(PGOReader.get())) continue; Func.populateCounters(); @@ -910,7 +1217,6 @@ static bool annotateAllFunctions( F->addFnAttr(llvm::Attribute::Cold); DEBUG(dbgs() << "Set cold attribute to function: " << F->getName() << "\n"); } - return true; } @@ -921,7 +1227,7 @@ PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename) } PreservedAnalyses PGOInstrumentationUse::run(Module &M, - AnalysisManager<Module> &AM) { + ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto LookupBPI = [&FAM](Function &F) { |