diff options
Diffstat (limited to 'contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp')
-rw-r--r-- | contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp | 560 |
1 files changed, 453 insertions, 107 deletions
diff --git a/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp index 2e3519e..8a209a1 100644 --- a/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -27,6 +27,14 @@ // -- We define Function* container class with custom "operator<" (FunctionPtr). // -- "FunctionPtr" instances are stored in std::set collection, so every // std::set::insert operation will give you result in log(N) time. +// +// As an optimization, a hash of the function structure is calculated first, and +// two functions are only compared if they have the same hash. This hash is +// cheap to compute, and has the property that if function F == G according to +// the comparison function, then hash(F) == hash(G). This consistency property +// is critical to ensuring all possible merging opportunities are exploited. +// Collisions in the hash affect the speed of the pass but not the correctness +// or determinism of the resulting transformation. // // When a match is found the functions are folded. If both functions are // overridable, we move the functionality into a new internal function and @@ -87,6 +95,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Hashing.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -97,12 +106,14 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/IR/ValueMap.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <vector> + using namespace llvm; #define DEBUG_TYPE "mergefunc" @@ -121,21 +132,64 @@ static cl::opt<unsigned> NumFunctionsForSanityCheck( namespace { +/// GlobalNumberState assigns an integer to each global value in the program, +/// which is used by the comparison routine to order references to globals. This +/// state must be preserved throughout the pass, because Functions and other +/// globals need to maintain their relative order. Globals are assigned a number +/// when they are first visited. This order is deterministic, and so the +/// assigned numbers are as well. When two functions are merged, neither number +/// is updated. If the symbols are weak, this would be incorrect. If they are +/// strong, then one will be replaced at all references to the other, and so +/// direct callsites will now see one or the other symbol, and no update is +/// necessary. Note that if we were guaranteed unique names, we could just +/// compare those, but this would not work for stripped bitcodes or for those +/// few symbols without a name. +class GlobalNumberState { + struct Config : ValueMapConfig<GlobalValue*> { + enum { FollowRAUW = false }; + }; + // Each GlobalValue is mapped to an identifier. The Config ensures when RAUW + // occurs, the mapping does not change. Tracking changes is unnecessary, and + // also problematic for weak symbols (which may be overwritten). + typedef ValueMap<GlobalValue *, uint64_t, Config> ValueNumberMap; + ValueNumberMap GlobalNumbers; + // The next unused serial number to assign to a global. + uint64_t NextNumber; + public: + GlobalNumberState() : GlobalNumbers(), NextNumber(0) {} + uint64_t getNumber(GlobalValue* Global) { + ValueNumberMap::iterator MapIter; + bool Inserted; + std::tie(MapIter, Inserted) = GlobalNumbers.insert({Global, NextNumber}); + if (Inserted) + NextNumber++; + return MapIter->second; + } + void clear() { + GlobalNumbers.clear(); + } +}; + /// FunctionComparator - Compares two functions to determine whether or not /// they will generate machine code with the same behaviour. DataLayout is /// used if available. The comparator always fails conservatively (erring on the /// side of claiming that two functions are different). class FunctionComparator { public: - FunctionComparator(const Function *F1, const Function *F2) - : FnL(F1), FnR(F2) {} + FunctionComparator(const Function *F1, const Function *F2, + GlobalNumberState* GN) + : FnL(F1), FnR(F2), GlobalNumbers(GN) {} /// Test whether the two functions have equivalent behaviour. int compare(); + /// Hash a function. Equivalent functions will have the same hash, and unequal + /// functions will have different hashes with high probability. + typedef uint64_t FunctionHash; + static FunctionHash functionHash(Function &); private: /// Test whether two basic blocks have equivalent behaviour. - int compare(const BasicBlock *BBL, const BasicBlock *BBR); + int cmpBasicBlocks(const BasicBlock *BBL, const BasicBlock *BBR); /// Constants comparison. /// Its analog to lexicographical comparison between hypothetical numbers @@ -241,6 +295,10 @@ private: /// If these properties are equal - compare their contents. int cmpConstants(const Constant *L, const Constant *R); + /// Compares two global values by number. Uses the GlobalNumbersState to + /// identify the same gobals across function calls. + int cmpGlobalValues(GlobalValue *L, GlobalValue *R); + /// Assign or look up previously assigned numbers for the two values, and /// return whether the numbers are equal. Numbers are assigned in the order /// visited. @@ -320,8 +378,9 @@ private: /// /// 1. If types are of different kind (different type IDs). /// Return result of type IDs comparison, treating them as numbers. - /// 2. If types are vectors or integers, compare Type* values as numbers. - /// 3. Types has same ID, so check whether they belongs to the next group: + /// 2. If types are integers, check that they have the same width. If they + /// are vectors, check that they have the same count and subtype. + /// 3. Types have the same ID, so check whether they are one of: /// * Void /// * Float /// * Double @@ -330,8 +389,7 @@ private: /// * PPC_FP128 /// * Label /// * Metadata - /// If so - return 0, yes - we can treat these types as equal only because - /// their IDs are same. + /// We can treat these types as equal whenever their IDs are same. /// 4. If Left and Right are pointers, return result of address space /// comparison (numbers comparison). We can treat pointer types of same /// address space as equal. @@ -343,11 +401,13 @@ private: int cmpTypes(Type *TyL, Type *TyR) const; int cmpNumbers(uint64_t L, uint64_t R) const; - int cmpAPInts(const APInt &L, const APInt &R) const; int cmpAPFloats(const APFloat &L, const APFloat &R) const; - int cmpStrings(StringRef L, StringRef R) const; + int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const; + int cmpMem(StringRef L, StringRef R) const; int cmpAttrs(const AttributeSet L, const AttributeSet R) const; + int cmpRangeMetadata(const MDNode* L, const MDNode* R) const; + int cmpOperandBundlesSchema(const Instruction *L, const Instruction *R) const; // The two functions undergoing comparison. const Function *FnL, *FnR; @@ -386,30 +446,30 @@ private: /// could be operands from further BBs we didn't scan yet. /// So it's impossible to use dominance properties in general. DenseMap<const Value*, int> sn_mapL, sn_mapR; + + // The global state we will use + GlobalNumberState* GlobalNumbers; }; class FunctionNode { mutable AssertingVH<Function> F; - + FunctionComparator::FunctionHash Hash; public: - FunctionNode(Function *F) : F(F) {} + // Note the hash is recalculated potentially multiple times, but it is cheap. + FunctionNode(Function *F) + : F(F), Hash(FunctionComparator::functionHash(*F)) {} Function *getFunc() const { return F; } + FunctionComparator::FunctionHash getHash() const { return Hash; } /// Replace the reference to the function F by the function G, assuming their /// implementations are equal. void replaceBy(Function *G) const { - assert(!(*this < FunctionNode(G)) && !(FunctionNode(G) < *this) && - "The two functions must be equal"); - F = G; } - void release() { F = 0; } - bool operator<(const FunctionNode &RHS) const { - return (FunctionComparator(F, RHS.getFunc()).compare()) == -1; - } + void release() { F = nullptr; } }; -} +} // end anonymous namespace int FunctionComparator::cmpNumbers(uint64_t L, uint64_t R) const { if (L < R) return -1; @@ -426,13 +486,25 @@ int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const { } int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const { - if (int Res = cmpNumbers((uint64_t)&L.getSemantics(), - (uint64_t)&R.getSemantics())) + // Floats are ordered first by semantics (i.e. float, double, half, etc.), + // then by value interpreted as a bitstring (aka APInt). + const fltSemantics &SL = L.getSemantics(), &SR = R.getSemantics(); + if (int Res = cmpNumbers(APFloat::semanticsPrecision(SL), + APFloat::semanticsPrecision(SR))) + return Res; + if (int Res = cmpNumbers(APFloat::semanticsMaxExponent(SL), + APFloat::semanticsMaxExponent(SR))) + return Res; + if (int Res = cmpNumbers(APFloat::semanticsMinExponent(SL), + APFloat::semanticsMinExponent(SR))) + return Res; + if (int Res = cmpNumbers(APFloat::semanticsSizeInBits(SL), + APFloat::semanticsSizeInBits(SR))) return Res; return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt()); } -int FunctionComparator::cmpStrings(StringRef L, StringRef R) const { +int FunctionComparator::cmpMem(StringRef L, StringRef R) const { // Prevent heavy comparison, compare sizes first. if (int Res = cmpNumbers(L.size(), R.size())) return Res; @@ -466,6 +538,59 @@ int FunctionComparator::cmpAttrs(const AttributeSet L, return 0; } +int FunctionComparator::cmpRangeMetadata(const MDNode* L, + const MDNode* R) const { + if (L == R) + return 0; + if (!L) + return -1; + if (!R) + return 1; + // Range metadata is a sequence of numbers. Make sure they are the same + // sequence. + // TODO: Note that as this is metadata, it is possible to drop and/or merge + // this data when considering functions to merge. Thus this comparison would + // return 0 (i.e. equivalent), but merging would become more complicated + // because the ranges would need to be unioned. It is not likely that + // functions differ ONLY in this metadata if they are actually the same + // function semantically. + if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands())) + return Res; + for (size_t I = 0; I < L->getNumOperands(); ++I) { + ConstantInt* LLow = mdconst::extract<ConstantInt>(L->getOperand(I)); + ConstantInt* RLow = mdconst::extract<ConstantInt>(R->getOperand(I)); + if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue())) + return Res; + } + return 0; +} + +int FunctionComparator::cmpOperandBundlesSchema(const Instruction *L, + const Instruction *R) const { + ImmutableCallSite LCS(L); + ImmutableCallSite RCS(R); + + assert(LCS && RCS && "Must be calls or invokes!"); + assert(LCS.isCall() == RCS.isCall() && "Can't compare otherwise!"); + + if (int Res = + cmpNumbers(LCS.getNumOperandBundles(), RCS.getNumOperandBundles())) + return Res; + + for (unsigned i = 0, e = LCS.getNumOperandBundles(); i != e; ++i) { + auto OBL = LCS.getOperandBundleAt(i); + auto OBR = RCS.getOperandBundleAt(i); + + if (int Res = OBL.getTagName().compare(OBR.getTagName())) + return Res; + + if (int Res = cmpNumbers(OBL.Inputs.size(), OBR.Inputs.size())) + return Res; + } + + return 0; +} + /// Constants comparison: /// 1. Check whether type of L constant could be losslessly bitcasted to R /// type. @@ -500,9 +625,9 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) { unsigned TyLWidth = 0; unsigned TyRWidth = 0; - if (const VectorType *VecTyL = dyn_cast<VectorType>(TyL)) + if (auto *VecTyL = dyn_cast<VectorType>(TyL)) TyLWidth = VecTyL->getBitWidth(); - if (const VectorType *VecTyR = dyn_cast<VectorType>(TyR)) + if (auto *VecTyR = dyn_cast<VectorType>(TyR)) TyRWidth = VecTyR->getBitWidth(); if (TyLWidth != TyRWidth) @@ -538,11 +663,29 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) { if (!L->isNullValue() && R->isNullValue()) return -1; + auto GlobalValueL = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(L)); + auto GlobalValueR = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(R)); + if (GlobalValueL && GlobalValueR) { + return cmpGlobalValues(GlobalValueL, GlobalValueR); + } + if (int Res = cmpNumbers(L->getValueID(), R->getValueID())) return Res; + if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) { + const auto *SeqR = cast<ConstantDataSequential>(R); + // This handles ConstantDataArray and ConstantDataVector. Note that we + // compare the two raw data arrays, which might differ depending on the host + // endianness. This isn't a problem though, because the endiness of a module + // will affect the order of the constants, but this order is the same + // for a given input module and host platform. + return cmpMem(SeqL->getRawDataValues(), SeqR->getRawDataValues()); + } + switch (L->getValueID()) { - case Value::UndefValueVal: return TypesRes; + case Value::UndefValueVal: + case Value::ConstantTokenNoneVal: + return TypesRes; case Value::ConstantIntVal: { const APInt &LInt = cast<ConstantInt>(L)->getValue(); const APInt &RInt = cast<ConstantInt>(R)->getValue(); @@ -609,19 +752,55 @@ int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) { } return 0; } - case Value::FunctionVal: - case Value::GlobalVariableVal: - case Value::GlobalAliasVal: - default: // Unknown constant, cast L and R pointers to numbers and compare. - return cmpNumbers((uint64_t)L, (uint64_t)R); + case Value::BlockAddressVal: { + const BlockAddress *LBA = cast<BlockAddress>(L); + const BlockAddress *RBA = cast<BlockAddress>(R); + if (int Res = cmpValues(LBA->getFunction(), RBA->getFunction())) + return Res; + if (LBA->getFunction() == RBA->getFunction()) { + // They are BBs in the same function. Order by which comes first in the + // BB order of the function. This order is deterministic. + Function* F = LBA->getFunction(); + BasicBlock *LBB = LBA->getBasicBlock(); + BasicBlock *RBB = RBA->getBasicBlock(); + if (LBB == RBB) + return 0; + for(BasicBlock &BB : F->getBasicBlockList()) { + if (&BB == LBB) { + assert(&BB != RBB); + return -1; + } + if (&BB == RBB) + return 1; + } + llvm_unreachable("Basic Block Address does not point to a basic block in " + "its function."); + return -1; + } else { + // cmpValues said the functions are the same. So because they aren't + // literally the same pointer, they must respectively be the left and + // right functions. + assert(LBA->getFunction() == FnL && RBA->getFunction() == FnR); + // cmpValues will tell us if these are equivalent BasicBlocks, in the + // context of their respective functions. + return cmpValues(LBA->getBasicBlock(), RBA->getBasicBlock()); + } } + default: // Unknown constant, abort. + DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n"); + llvm_unreachable("Constant ValueID not recognized."); + return -1; + } +} + +int FunctionComparator::cmpGlobalValues(GlobalValue *L, GlobalValue* R) { + return cmpNumbers(GlobalNumbers->getNumber(L), GlobalNumbers->getNumber(R)); } /// cmpType - compares two types, /// defines total ordering among the types set. /// See method declaration comments for more details. int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { - PointerType *PTyL = dyn_cast<PointerType>(TyL); PointerType *PTyR = dyn_cast<PointerType>(TyR); @@ -642,10 +821,15 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { llvm_unreachable("Unknown type!"); // Fall through in Release mode. case Type::IntegerTyID: - case Type::VectorTyID: - // TyL == TyR would have returned true earlier. - return cmpNumbers((uint64_t)TyL, (uint64_t)TyR); - + return cmpNumbers(cast<IntegerType>(TyL)->getBitWidth(), + cast<IntegerType>(TyR)->getBitWidth()); + case Type::VectorTyID: { + VectorType *VTyL = cast<VectorType>(TyL), *VTyR = cast<VectorType>(TyR); + if (int Res = cmpNumbers(VTyL->getNumElements(), VTyR->getNumElements())) + return Res; + return cmpTypes(VTyL->getElementType(), VTyR->getElementType()); + } + // TyL == TyR would have returned true earlier, because types are uniqued. case Type::VoidTyID: case Type::FloatTyID: case Type::DoubleTyID: @@ -654,6 +838,7 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { case Type::PPC_FP128TyID: case Type::LabelTyID: case Type::MetadataTyID: + case Type::TokenTyID: return 0; case Type::PointerTyID: { @@ -759,8 +944,8 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpNumbers(LI->getSynchScope(), cast<LoadInst>(R)->getSynchScope())) return Res; - return cmpNumbers((uint64_t)LI->getMetadata(LLVMContext::MD_range), - (uint64_t)cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range)); + return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range), + cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range)); } if (const StoreInst *SI = dyn_cast<StoreInst>(L)) { if (int Res = @@ -783,20 +968,24 @@ int FunctionComparator::cmpOperations(const Instruction *L, if (int Res = cmpAttrs(CI->getAttributes(), cast<CallInst>(R)->getAttributes())) return Res; - return cmpNumbers( - (uint64_t)CI->getMetadata(LLVMContext::MD_range), - (uint64_t)cast<CallInst>(R)->getMetadata(LLVMContext::MD_range)); + if (int Res = cmpOperandBundlesSchema(CI, R)) + return Res; + return cmpRangeMetadata( + CI->getMetadata(LLVMContext::MD_range), + cast<CallInst>(R)->getMetadata(LLVMContext::MD_range)); } - if (const InvokeInst *CI = dyn_cast<InvokeInst>(L)) { - if (int Res = cmpNumbers(CI->getCallingConv(), + if (const InvokeInst *II = dyn_cast<InvokeInst>(L)) { + if (int Res = cmpNumbers(II->getCallingConv(), cast<InvokeInst>(R)->getCallingConv())) return Res; if (int Res = - cmpAttrs(CI->getAttributes(), cast<InvokeInst>(R)->getAttributes())) + cmpAttrs(II->getAttributes(), cast<InvokeInst>(R)->getAttributes())) + return Res; + if (int Res = cmpOperandBundlesSchema(II, R)) return Res; - return cmpNumbers( - (uint64_t)CI->getMetadata(LLVMContext::MD_range), - (uint64_t)cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range)); + return cmpRangeMetadata( + II->getMetadata(LLVMContext::MD_range), + cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range)); } if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) { ArrayRef<unsigned> LIndices = IVI->getIndices(); @@ -876,9 +1065,8 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, if (GEPL->accumulateConstantOffset(DL, OffsetL) && GEPR->accumulateConstantOffset(DL, OffsetR)) return cmpAPInts(OffsetL, OffsetR); - - if (int Res = cmpNumbers((uint64_t)GEPL->getPointerOperand()->getType(), - (uint64_t)GEPR->getPointerOperand()->getType())) + if (int Res = cmpTypes(GEPL->getSourceElementType(), + GEPR->getSourceElementType())) return Res; if (int Res = cmpNumbers(GEPL->getNumOperands(), GEPR->getNumOperands())) @@ -892,6 +1080,28 @@ int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, return 0; } +int FunctionComparator::cmpInlineAsm(const InlineAsm *L, + const InlineAsm *R) const { + // InlineAsm's are uniqued. If they are the same pointer, obviously they are + // the same, otherwise compare the fields. + if (L == R) + return 0; + if (int Res = cmpTypes(L->getFunctionType(), R->getFunctionType())) + return Res; + if (int Res = cmpMem(L->getAsmString(), R->getAsmString())) + return Res; + if (int Res = cmpMem(L->getConstraintString(), R->getConstraintString())) + return Res; + if (int Res = cmpNumbers(L->hasSideEffects(), R->hasSideEffects())) + return Res; + if (int Res = cmpNumbers(L->isAlignStack(), R->isAlignStack())) + return Res; + if (int Res = cmpNumbers(L->getDialect(), R->getDialect())) + return Res; + llvm_unreachable("InlineAsm blocks were not uniqued."); + return 0; +} + /// Compare two values used by the two functions under pair-wise comparison. If /// this is the first time the values are seen, they're added to the mapping so /// that we will detect mismatches on next use. @@ -926,7 +1136,7 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) { const InlineAsm *InlineAsmR = dyn_cast<InlineAsm>(R); if (InlineAsmL && InlineAsmR) - return cmpNumbers((uint64_t)L, (uint64_t)R); + return cmpInlineAsm(InlineAsmL, InlineAsmR); if (InlineAsmL) return 1; if (InlineAsmR) @@ -938,12 +1148,13 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) { return cmpNumbers(LeftSN.first->second, RightSN.first->second); } // Test whether two basic blocks have equivalent behaviour. -int FunctionComparator::compare(const BasicBlock *BBL, const BasicBlock *BBR) { +int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL, + const BasicBlock *BBR) { BasicBlock::const_iterator InstL = BBL->begin(), InstLE = BBL->end(); BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end(); do { - if (int Res = cmpValues(InstL, InstR)) + if (int Res = cmpValues(&*InstL, &*InstR)) return Res; const GetElementPtrInst *GEPL = dyn_cast<GetElementPtrInst>(InstL); @@ -961,7 +1172,7 @@ int FunctionComparator::compare(const BasicBlock *BBL, const BasicBlock *BBR) { if (int Res = cmpGEPs(GEPL, GEPR)) return Res; } else { - if (int Res = cmpOperations(InstL, InstR)) + if (int Res = cmpOperations(&*InstL, &*InstR)) return Res; assert(InstL->getNumOperands() == InstR->getNumOperands()); @@ -970,11 +1181,8 @@ int FunctionComparator::compare(const BasicBlock *BBL, const BasicBlock *BBR) { Value *OpR = InstR->getOperand(i); if (int Res = cmpValues(OpL, OpR)) return Res; - if (int Res = cmpNumbers(OpL->getValueID(), OpR->getValueID())) - return Res; - // TODO: Already checked in cmpOperation - if (int Res = cmpTypes(OpL->getType(), OpR->getType())) - return Res; + // cmpValues should ensure this is true. + assert(cmpTypes(OpL->getType(), OpR->getType()) == 0); } } @@ -990,7 +1198,6 @@ int FunctionComparator::compare(const BasicBlock *BBL, const BasicBlock *BBR) { // Test whether the two functions have equivalent behaviour. int FunctionComparator::compare() { - sn_mapL.clear(); sn_mapR.clear(); @@ -1001,7 +1208,7 @@ int FunctionComparator::compare() { return Res; if (FnL->hasGC()) { - if (int Res = cmpNumbers((uint64_t)FnL->getGC(), (uint64_t)FnR->getGC())) + if (int Res = cmpMem(FnL->getGC(), FnR->getGC())) return Res; } @@ -1009,7 +1216,7 @@ int FunctionComparator::compare() { return Res; if (FnL->hasSection()) { - if (int Res = cmpStrings(FnL->getSection(), FnR->getSection())) + if (int Res = cmpMem(FnL->getSection(), FnR->getSection())) return Res; } @@ -1033,7 +1240,7 @@ int FunctionComparator::compare() { ArgRI = FnR->arg_begin(), ArgLE = FnL->arg_end(); ArgLI != ArgLE; ++ArgLI, ++ArgRI) { - if (cmpValues(ArgLI, ArgRI) != 0) + if (cmpValues(&*ArgLI, &*ArgRI) != 0) llvm_unreachable("Arguments repeat!"); } @@ -1055,7 +1262,7 @@ int FunctionComparator::compare() { if (int Res = cmpValues(BBL, BBR)) return Res; - if (int Res = compare(BBL, BBR)) + if (int Res = cmpBasicBlocks(BBL, BBR)) return Res; const TerminatorInst *TermL = BBL->getTerminator(); @@ -1074,6 +1281,68 @@ int FunctionComparator::compare() { } namespace { +// Accumulate the hash of a sequence of 64-bit integers. This is similar to a +// hash of a sequence of 64bit ints, but the entire input does not need to be +// available at once. This interface is necessary for functionHash because it +// needs to accumulate the hash as the structure of the function is traversed +// without saving these values to an intermediate buffer. This form of hashing +// is not often needed, as usually the object to hash is just read from a +// buffer. +class HashAccumulator64 { + uint64_t Hash; +public: + // Initialize to random constant, so the state isn't zero. + HashAccumulator64() { Hash = 0x6acaa36bef8325c5ULL; } + void add(uint64_t V) { + Hash = llvm::hashing::detail::hash_16_bytes(Hash, V); + } + // No finishing is required, because the entire hash value is used. + uint64_t getHash() { return Hash; } +}; +} // end anonymous namespace + +// A function hash is calculated by considering only the number of arguments and +// whether a function is varargs, the order of basic blocks (given by the +// successors of each basic block in depth first order), and the order of +// opcodes of each instruction within each of these basic blocks. This mirrors +// the strategy compare() uses to compare functions by walking the BBs in depth +// first order and comparing each instruction in sequence. Because this hash +// does not look at the operands, it is insensitive to things such as the +// target of calls and the constants used in the function, which makes it useful +// when possibly merging functions which are the same modulo constants and call +// targets. +FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) { + HashAccumulator64 H; + H.add(F.isVarArg()); + H.add(F.arg_size()); + + SmallVector<const BasicBlock *, 8> BBs; + SmallSet<const BasicBlock *, 16> VisitedBBs; + + // Walk the blocks in the same order as FunctionComparator::cmpBasicBlocks(), + // accumulating the hash of the function "structure." (BB and opcode sequence) + BBs.push_back(&F.getEntryBlock()); + VisitedBBs.insert(BBs[0]); + while (!BBs.empty()) { + const BasicBlock *BB = BBs.pop_back_val(); + // This random value acts as a block header, as otherwise the partition of + // opcodes into BBs wouldn't affect the hash, only the order of the opcodes + H.add(45798); + for (auto &Inst : *BB) { + H.add(Inst.getOpcode()); + } + const TerminatorInst *Term = BB->getTerminator(); + for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i) { + if (!VisitedBBs.insert(Term->getSuccessor(i)).second) + continue; + BBs.push_back(Term->getSuccessor(i)); + } + } + return H.getHash(); +} + + +namespace { /// MergeFunctions finds functions which will generate identical machine code, /// by considering all pointer types to be equivalent. Once identified, @@ -1084,14 +1353,31 @@ class MergeFunctions : public ModulePass { public: static char ID; MergeFunctions() - : ModulePass(ID), HasGlobalAliases(false) { + : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)), FNodesInTree(), + HasGlobalAliases(false) { initializeMergeFunctionsPass(*PassRegistry::getPassRegistry()); } bool runOnModule(Module &M) override; private: - typedef std::set<FunctionNode> FnTreeType; + // The function comparison operator is provided here so that FunctionNodes do + // not need to become larger with another pointer. + class FunctionNodeCmp { + GlobalNumberState* GlobalNumbers; + public: + FunctionNodeCmp(GlobalNumberState* GN) : GlobalNumbers(GN) {} + bool operator()(const FunctionNode &LHS, const FunctionNode &RHS) const { + // Order first by hashes, then full function comparison. + if (LHS.getHash() != RHS.getHash()) + return LHS.getHash() < RHS.getHash(); + FunctionComparator FCmp(LHS.getFunc(), RHS.getFunc(), GlobalNumbers); + return FCmp.compare() == -1; + } + }; + typedef std::set<FunctionNode, FunctionNodeCmp> FnTreeType; + + GlobalNumberState GlobalNumbers; /// A work queue of functions that may have been modified and should be /// analyzed again. @@ -1133,17 +1419,23 @@ private: void writeAlias(Function *F, Function *G); /// Replace function F with function G in the function tree. - void replaceFunctionInTree(FnTreeType::iterator &IterToF, Function *G); + void replaceFunctionInTree(const FunctionNode &FN, Function *G); /// The set of all distinct functions. Use the insert() and remove() methods - /// to modify it. + /// to modify it. The map allows efficient lookup and deferring of Functions. FnTreeType FnTree; + // Map functions to the iterators of the FunctionNode which contains them + // in the FnTree. This must be updated carefully whenever the FnTree is + // modified, i.e. in insert(), remove(), and replaceFunctionInTree(), to avoid + // dangling iterators into FnTree. The invariant that preserves this is that + // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree. + ValueMap<Function*, FnTreeType::iterator> FNodesInTree; /// Whether or not the target supports global aliases. bool HasGlobalAliases; }; -} // end anonymous namespace +} // end anonymous namespace char MergeFunctions::ID = 0; INITIALIZE_PASS(MergeFunctions, "mergefunc", "Merge Functions", false, false) @@ -1166,8 +1458,8 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { for (std::vector<WeakVH>::iterator J = I; J != E && j < Max; ++J, ++j) { Function *F1 = cast<Function>(*I); Function *F2 = cast<Function>(*J); - int Res1 = FunctionComparator(F1, F2).compare(); - int Res2 = FunctionComparator(F2, F1).compare(); + int Res1 = FunctionComparator(F1, F2, &GlobalNumbers).compare(); + int Res2 = FunctionComparator(F2, F1, &GlobalNumbers).compare(); // If F1 <= F2, then F2 >= F1, otherwise report failure. if (Res1 != -Res2) { @@ -1188,8 +1480,8 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { continue; Function *F3 = cast<Function>(*K); - int Res3 = FunctionComparator(F1, F3).compare(); - int Res4 = FunctionComparator(F2, F3).compare(); + int Res3 = FunctionComparator(F1, F3, &GlobalNumbers).compare(); + int Res4 = FunctionComparator(F2, F3, &GlobalNumbers).compare(); bool Transitive = true; @@ -1227,11 +1519,33 @@ bool MergeFunctions::doSanityCheck(std::vector<WeakVH> &Worklist) { bool MergeFunctions::runOnModule(Module &M) { bool Changed = false; - for (Module::iterator I = M.begin(), E = M.end(); I != E; ++I) { - if (!I->isDeclaration() && !I->hasAvailableExternallyLinkage()) - Deferred.push_back(WeakVH(I)); + // All functions in the module, ordered by hash. Functions with a unique + // hash value are easily eliminated. + std::vector<std::pair<FunctionComparator::FunctionHash, Function *>> + HashedFuncs; + for (Function &Func : M) { + if (!Func.isDeclaration() && !Func.hasAvailableExternallyLinkage()) { + HashedFuncs.push_back({FunctionComparator::functionHash(Func), &Func}); + } } + std::stable_sort( + HashedFuncs.begin(), HashedFuncs.end(), + [](const std::pair<FunctionComparator::FunctionHash, Function *> &a, + const std::pair<FunctionComparator::FunctionHash, Function *> &b) { + return a.first < b.first; + }); + + auto S = HashedFuncs.begin(); + for (auto I = HashedFuncs.begin(), IE = HashedFuncs.end(); I != IE; ++I) { + // If the hash value matches the previous value or the next one, we must + // consider merging it. Otherwise it is dropped and never considered again. + if ((I != S && std::prev(I)->first == I->first) || + (std::next(I) != IE && std::next(I)->first == I->first) ) { + Deferred.push_back(WeakVH(I->second)); + } + } + do { std::vector<WeakVH> Worklist; Deferred.swap(Worklist); @@ -1270,6 +1584,7 @@ bool MergeFunctions::runOnModule(Module &M) { } while (!Deferred.empty()); FnTree.clear(); + GlobalNumbers.clear(); return Changed; } @@ -1282,6 +1597,32 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { ++UI; CallSite CS(U->getUser()); if (CS && CS.isCallee(U)) { + // Transfer the called function's attributes to the call site. Due to the + // bitcast we will 'lose' ABI changing attributes because the 'called + // function' is no longer a Function* but the bitcast. Code that looks up + // the attributes from the called function will fail. + + // FIXME: This is not actually true, at least not anymore. The callsite + // will always have the same ABI affecting attributes as the callee, + // because otherwise the original input has UB. Note that Old and New + // always have matching ABI, so no attributes need to be changed. + // Transferring other attributes may help other optimizations, but that + // should be done uniformly and not in this ad-hoc way. + auto &Context = New->getContext(); + auto NewFuncAttrs = New->getAttributes(); + auto CallSiteAttrs = CS.getAttributes(); + + CallSiteAttrs = CallSiteAttrs.addAttributes( + Context, AttributeSet::ReturnIndex, NewFuncAttrs.getRetAttributes()); + + for (unsigned argIdx = 0; argIdx < CS.arg_size(); argIdx++) { + AttributeSet Attrs = NewFuncAttrs.getParamAttributes(argIdx); + if (Attrs.getNumSlots()) + CallSiteAttrs = CallSiteAttrs.addAttributes(Context, argIdx, Attrs); + } + + CS.setAttributes(CallSiteAttrs); + remove(CS.getInstruction()->getParent()->getParent()); U->set(BitcastNew); } @@ -1352,15 +1693,15 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { SmallVector<Value *, 16> Args; unsigned i = 0; FunctionType *FFTy = F->getFunctionType(); - for (Function::arg_iterator AI = NewG->arg_begin(), AE = NewG->arg_end(); - AI != AE; ++AI) { - Args.push_back(createCast(Builder, (Value*)AI, FFTy->getParamType(i))); + for (Argument & AI : NewG->args()) { + Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i))); ++i; } CallInst *CI = Builder.CreateCall(F, Args); CI->setTailCall(); CI->setCallingConv(F->getCallingConv()); + CI->setAttributes(F->getAttributes()); if (NewG->getReturnType()->isVoidTy()) { Builder.CreateRetVoid(); } else { @@ -1379,8 +1720,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { // Replace G with an alias to F and delete G. void MergeFunctions::writeAlias(Function *F, Function *G) { - PointerType *PTy = G->getType(); - auto *GA = GlobalAlias::create(PTy, G->getLinkage(), "", F); + auto *GA = GlobalAlias::create(G->getLinkage(), "", F); F->setAlignment(std::max(F->getAlignment(), G->getAlignment())); GA->takeName(G); GA->setVisibility(G->getVisibility()); @@ -1425,19 +1765,24 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { ++NumFunctionsMerged; } -/// Replace function F for function G in the map. -void MergeFunctions::replaceFunctionInTree(FnTreeType::iterator &IterToF, +/// Replace function F by function G. +void MergeFunctions::replaceFunctionInTree(const FunctionNode &FN, Function *G) { - Function *F = IterToF->getFunc(); - - // A total order is already guaranteed otherwise because we process strong - // functions before weak functions. - assert(((F->mayBeOverridden() && G->mayBeOverridden()) || - (!F->mayBeOverridden() && !G->mayBeOverridden())) && - "Only change functions if both are strong or both are weak"); - (void)F; - - IterToF->replaceBy(G); + Function *F = FN.getFunc(); + assert(FunctionComparator(F, G, &GlobalNumbers).compare() == 0 && + "The two functions must be equal"); + + auto I = FNodesInTree.find(F); + assert(I != FNodesInTree.end() && "F should be in FNodesInTree"); + assert(FNodesInTree.count(G) == 0 && "FNodesInTree should not contain G"); + + FnTreeType::iterator IterToFNInFnTree = I->second; + assert(&(*IterToFNInFnTree) == &FN && "F should map to FN in FNodesInTree."); + // Remove F -> FN and insert G -> FN + FNodesInTree.erase(I); + FNodesInTree.insert({G, IterToFNInFnTree}); + // Replace F with G in FN, which is stored inside the FnTree. + FN.replaceBy(G); } // Insert a ComparableFunction into the FnTree, or merge it away if equal to one @@ -1447,6 +1792,8 @@ bool MergeFunctions::insert(Function *NewFunction) { FnTree.insert(FunctionNode(NewFunction)); if (Result.second) { + assert(FNodesInTree.count(NewFunction) == 0); + FNodesInTree.insert({NewFunction, Result.first}); DEBUG(dbgs() << "Inserting as unique: " << NewFunction->getName() << '\n'); return false; } @@ -1476,7 +1823,7 @@ bool MergeFunctions::insert(Function *NewFunction) { if (OldF.getFunc()->getName() > NewFunction->getName()) { // Swap the two functions. Function *F = OldF.getFunc(); - replaceFunctionInTree(Result.first, NewFunction); + replaceFunctionInTree(*Result.first, NewFunction); NewFunction = F; assert(OldF.getFunc() != F && "Must have swapped the functions."); } @@ -1495,18 +1842,13 @@ bool MergeFunctions::insert(Function *NewFunction) { // Remove a function from FnTree. If it was already in FnTree, add // it to Deferred so that we'll look at it in the next round. void MergeFunctions::remove(Function *F) { - // We need to make sure we remove F, not a function "equal" to F per the - // function equality comparator. - FnTreeType::iterator found = FnTree.find(FunctionNode(F)); - size_t Erased = 0; - if (found != FnTree.end() && found->getFunc() == F) { - Erased = 1; - FnTree.erase(found); - } - - if (Erased) { - DEBUG(dbgs() << "Removed " << F->getName() - << " from set and deferred it.\n"); + auto I = FNodesInTree.find(F); + if (I != FNodesInTree.end()) { + DEBUG(dbgs() << "Deferred " << F->getName()<< ".\n"); + FnTree.erase(I->second); + // I->second has been invalidated, remove it from the FNodesInTree map to + // preserve the invariant. + FNodesInTree.erase(I); Deferred.emplace_back(F); } } @@ -1516,6 +1858,8 @@ void MergeFunctions::remove(Function *F) { void MergeFunctions::removeUsers(Value *V) { std::vector<Value *> Worklist; Worklist.push_back(V); + SmallSet<Value*, 8> Visited; + Visited.insert(V); while (!Worklist.empty()) { Value *V = Worklist.back(); Worklist.pop_back(); @@ -1526,8 +1870,10 @@ void MergeFunctions::removeUsers(Value *V) { } else if (isa<GlobalValue>(U)) { // do nothing } else if (Constant *C = dyn_cast<Constant>(U)) { - for (User *UU : C->users()) - Worklist.push_back(UU); + for (User *UU : C->users()) { + if (!Visited.insert(UU).second) + Worklist.push_back(UU); + } } } } |