diff options
Diffstat (limited to 'contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp')
-rw-r--r-- | contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp | 766 |
1 files changed, 624 insertions, 142 deletions
diff --git a/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 844cc0f..00769cd 100644 --- a/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -25,6 +25,20 @@ // returns 0, or a single vtable's function returns 1, replace each virtual // call with a comparison of the vptr against that vtable's address. // +// This pass is intended to be used during the regular and thin LTO pipelines. +// During regular LTO, the pass determines the best optimization for each +// virtual call and applies the resolutions directly to virtual calls that are +// eligible for virtual call optimization (i.e. calls that use either of the +// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics). During +// ThinLTO, the pass operates in two phases: +// - Export phase: this is run during the thin link over a single merged module +// that contains all vtables with !type metadata that participate in the link. +// The pass computes a resolution for each virtual call and stores it in the +// type identifier summary. +// - Import phase: this is run during the thin backends over the individual +// modules. The pass applies the resolutions previously computed during the +// export phase to each eligible virtual call. 
+// //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/WholeProgramDevirt.h" @@ -32,9 +46,11 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/iterator_range.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" @@ -54,12 +70,16 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ModuleSummaryIndexYAML.h" #include "llvm/Pass.h" #include "llvm/PassRegistry.h" #include "llvm/PassSupport.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/Transforms/Utils/Evaluator.h" #include <algorithm> #include <cstddef> @@ -72,6 +92,26 @@ using namespace wholeprogramdevirt; #define DEBUG_TYPE "wholeprogramdevirt" +static cl::opt<PassSummaryAction> ClSummaryAction( + "wholeprogramdevirt-summary-action", + cl::desc("What to do with the summary when running this pass"), + cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"), + clEnumValN(PassSummaryAction::Import, "import", + "Import typeid resolutions from summary and globals"), + clEnumValN(PassSummaryAction::Export, "export", + "Export typeid resolutions to summary and globals")), + cl::Hidden); + +static cl::opt<std::string> ClReadSummary( + "wholeprogramdevirt-read-summary", + cl::desc("Read summary from given YAML file before running pass"), + cl::Hidden); + +static cl::opt<std::string> ClWriteSummary( + "wholeprogramdevirt-write-summary", + cl::desc("Write summary to given YAML file after 
running pass"), + cl::Hidden); + // Find the minimum offset that we may store a value of size Size bits at. If // IsAfter is set, look for an offset before the object, otherwise look for an // offset after the object. @@ -259,15 +299,92 @@ struct VirtualCallSite { } }; +// Call site information collected for a specific VTableSlot and possibly a list +// of constant integer arguments. The grouping by arguments is handled by the +// VTableSlotInfo class. +struct CallSiteInfo { + /// The set of call sites for this slot. Used during regular LTO and the + /// import phase of ThinLTO (as well as the export phase of ThinLTO for any + /// call sites that appear in the merged module itself); in each of these + /// cases we are directly operating on the call sites at the IR level. + std::vector<VirtualCallSite> CallSites; + + // These fields are used during the export phase of ThinLTO and reflect + // information collected from function summaries. + + /// Whether any function summary contains an llvm.assume(llvm.type.test) for + /// this slot. + bool SummaryHasTypeTestAssumeUsers; + + /// CFI-specific: a vector containing the list of function summaries that use + /// the llvm.type.checked.load intrinsic and therefore will require + /// resolutions for llvm.type.test in order to implement CFI checks if + /// devirtualization was unsuccessful. If devirtualization was successful, the + /// pass will clear this vector by calling markDevirt(). If at the end of the + /// pass the vector is non-empty, we will need to add a use of llvm.type.test + /// to each of the function summaries in the vector. + std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers; + + bool isExported() const { + return SummaryHasTypeTestAssumeUsers || + !SummaryTypeCheckedLoadUsers.empty(); + } + + /// As explained in the comment for SummaryTypeCheckedLoadUsers. + void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); } +}; + +// Call site information collected for a specific VTableSlot. 
+struct VTableSlotInfo { + // The set of call sites which do not have all constant integer arguments + // (excluding "this"). + CallSiteInfo CSInfo; + + // The set of call sites with all constant integer arguments (excluding + // "this"), grouped by argument list. + std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo; + + void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses); + +private: + CallSiteInfo &findCallSiteInfo(CallSite CS); +}; + +CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) { + std::vector<uint64_t> Args; + auto *CI = dyn_cast<IntegerType>(CS.getType()); + if (!CI || CI->getBitWidth() > 64 || CS.arg_empty()) + return CSInfo; + for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) { + auto *CI = dyn_cast<ConstantInt>(Arg); + if (!CI || CI->getBitWidth() > 64) + return CSInfo; + Args.push_back(CI->getZExtValue()); + } + return ConstCSInfo[Args]; +} + +void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, + unsigned *NumUnsafeUses) { + findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses}); +} + struct DevirtModule { Module &M; + function_ref<AAResults &(Function &)> AARGetter; + + ModuleSummaryIndex *ExportSummary; + const ModuleSummaryIndex *ImportSummary; + IntegerType *Int8Ty; PointerType *Int8PtrTy; IntegerType *Int32Ty; + IntegerType *Int64Ty; + IntegerType *IntPtrTy; bool RemarksEnabled; - MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots; + MapVector<VTableSlot, VTableSlotInfo> CallSlots; // This map keeps track of the number of "unsafe" uses of a loaded function // pointer. The key is the associated llvm.type.test intrinsic call generated @@ -279,11 +396,18 @@ struct DevirtModule { // true. 
std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; - DevirtModule(Module &M) - : M(M), Int8Ty(Type::getInt8Ty(M.getContext())), + DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, + ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary) + : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary), + ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(Type::getInt8PtrTy(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())), - RemarksEnabled(areRemarksEnabled()) {} + Int64Ty(Type::getInt64Ty(M.getContext())), + IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), + RemarksEnabled(areRemarksEnabled()) { + assert(!(ExportSummary && ImportSummary)); + } bool areRemarksEnabled(); @@ -298,57 +422,169 @@ struct DevirtModule { tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset); + + void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn, + bool &IsExported); bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites); + VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res); + bool tryEvaluateFunctionsWithArgs( MutableArrayRef<VirtualCallTarget> TargetsForSlot, - ArrayRef<ConstantInt *> Args); - bool tryUniformRetValOpt(IntegerType *RetType, - MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites); + ArrayRef<uint64_t> Args); + + void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, + uint64_t TheRetVal); + bool tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, + CallSiteInfo &CSInfo, + WholeProgramDevirtResolution::ByArg *Res); + + // Returns the global symbol name that is used to export information about the + // given vtable slot and list of arguments. 
+ std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name); + + // This function is called during the export phase to create a symbol + // definition containing information about the given vtable slot and list of + // arguments. + void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, + Constant *C); + + // This function is called during the import phase to create a reference to + // the symbol definition created during the export phase. + Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name, unsigned AbsWidth = 0); + + void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, + Constant *UniqueMemberAddr); bool tryUniqueRetValOpt(unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites); + CallSiteInfo &CSInfo, + WholeProgramDevirtResolution::ByArg *Res, + VTableSlot Slot, ArrayRef<uint64_t> Args); + + void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, + Constant *Byte, Constant *Bit); bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot, - ArrayRef<VirtualCallSite> CallSites); + VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, VTableSlot Slot); void rebuildGlobal(VTableBits &B); + // Apply the summary resolution for Slot to all virtual calls in SlotInfo. + void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo); + + // If we were able to eliminate all unsafe uses for a type checked load, + // eliminate the associated type tests by replacing them with true. + void removeRedundantTypeTests(); + bool run(); + + // Lower the module using the action and summary passed as command line + // arguments. For testing purposes only. 
+ static bool runForTesting(Module &M, + function_ref<AAResults &(Function &)> AARGetter); }; struct WholeProgramDevirt : public ModulePass { static char ID; - WholeProgramDevirt() : ModulePass(ID) { + bool UseCommandLine = false; + + ModuleSummaryIndex *ExportSummary; + const ModuleSummaryIndex *ImportSummary; + + WholeProgramDevirt() : ModulePass(ID), UseCommandLine(true) { + initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); + } + + WholeProgramDevirt(ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary) + : ModulePass(ID), ExportSummary(ExportSummary), + ImportSummary(ImportSummary) { initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); } bool runOnModule(Module &M) override { if (skipModule(M)) return false; + if (UseCommandLine) + return DevirtModule::runForTesting(M, LegacyAARGetter(*this)); + return DevirtModule(M, LegacyAARGetter(*this), ExportSummary, ImportSummary) + .run(); + } - return DevirtModule(M).run(); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); } }; } // end anonymous namespace -INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt", - "Whole program devirtualization", false, false) +INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt", + "Whole program devirtualization", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt", + "Whole program devirtualization", false, false) char WholeProgramDevirt::ID = 0; -ModulePass *llvm::createWholeProgramDevirtPass() { - return new WholeProgramDevirt; +ModulePass * +llvm::createWholeProgramDevirtPass(ModuleSummaryIndex *ExportSummary, + const ModuleSummaryIndex *ImportSummary) { + return new WholeProgramDevirt(ExportSummary, ImportSummary); } PreservedAnalyses WholeProgramDevirtPass::run(Module &M, - 
ModuleAnalysisManager &) { - if (!DevirtModule(M).run()) + ModuleAnalysisManager &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto AARGetter = [&](Function &F) -> AAResults & { + return FAM.getResult<AAManager>(F); + }; + if (!DevirtModule(M, AARGetter, nullptr, nullptr).run()) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } +bool DevirtModule::runForTesting( + Module &M, function_ref<AAResults &(Function &)> AARGetter) { + ModuleSummaryIndex Summary; + + // Handle the command-line summary arguments. This code is for testing + // purposes only, so we handle errors directly. + if (!ClReadSummary.empty()) { + ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary + + ": "); + auto ReadSummaryFile = + ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary))); + + yaml::Input In(ReadSummaryFile->getBuffer()); + In >> Summary; + ExitOnErr(errorCodeToError(In.error())); + } + + bool Changed = + DevirtModule( + M, AARGetter, + ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, + ClSummaryAction == PassSummaryAction::Import ? 
&Summary : nullptr) + .run(); + + if (!ClWriteSummary.empty()) { + ExitOnError ExitOnErr( + "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": "); + std::error_code EC; + raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text); + ExitOnErr(errorCodeToError(EC)); + + yaml::Output Out(OS); + Out << Summary; + } + + return Changed; +} + void DevirtModule::buildTypeIdentifierMap( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) { @@ -443,9 +679,31 @@ bool DevirtModule::tryFindVirtualCallTargets( return !TargetsForSlot.empty(); } +void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, + Constant *TheFn, bool &IsExported) { + auto Apply = [&](CallSiteInfo &CSInfo) { + for (auto &&VCallSite : CSInfo.CallSites) { + if (RemarksEnabled) + VCallSite.emitRemark("single-impl", TheFn->getName()); + VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( + TheFn, VCallSite.CS.getCalledValue()->getType())); + // This use is no longer unsafe. + if (VCallSite.NumUnsafeUses) + --*VCallSite.NumUnsafeUses; + } + if (CSInfo.isExported()) { + IsExported = true; + CSInfo.markDevirt(); + } + }; + Apply(SlotInfo.CSInfo); + for (auto &P : SlotInfo.ConstCSInfo) + Apply(P.second); +} + bool DevirtModule::trySingleImplDevirt( MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites) { + VTableSlotInfo &SlotInfo, WholeProgramDevirtResolution *Res) { // See if the program contains a single implementation of this virtual // function. Function *TheFn = TargetsForSlot[0].Fn; @@ -453,39 +711,51 @@ bool DevirtModule::trySingleImplDevirt( if (TheFn != Target.Fn) return false; + // If so, update each call site to call that implementation directly. if (RemarksEnabled) TargetsForSlot[0].WasDevirt = true; - // If so, update each call site to call that implementation directly. 
- for (auto &&VCallSite : CallSites) { - if (RemarksEnabled) - VCallSite.emitRemark("single-impl", TheFn->getName()); - VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( - TheFn, VCallSite.CS.getCalledValue()->getType())); - // This use is no longer unsafe. - if (VCallSite.NumUnsafeUses) - --*VCallSite.NumUnsafeUses; + + bool IsExported = false; + applySingleImplDevirt(SlotInfo, TheFn, IsExported); + if (!IsExported) + return false; + + // If the only implementation has local linkage, we must promote to external + // to make it visible to thin LTO objects. We can only get here during the + // ThinLTO export phase. + if (TheFn->hasLocalLinkage()) { + TheFn->setLinkage(GlobalValue::ExternalLinkage); + TheFn->setVisibility(GlobalValue::HiddenVisibility); + TheFn->setName(TheFn->getName() + "$merged"); } + + Res->TheKind = WholeProgramDevirtResolution::SingleImpl; + Res->SingleImplName = TheFn->getName(); + return true; } bool DevirtModule::tryEvaluateFunctionsWithArgs( MutableArrayRef<VirtualCallTarget> TargetsForSlot, - ArrayRef<ConstantInt *> Args) { + ArrayRef<uint64_t> Args) { // Evaluate each function and store the result in each target's RetVal // field. 
for (VirtualCallTarget &Target : TargetsForSlot) { if (Target.Fn->arg_size() != Args.size() + 1) return false; - for (unsigned I = 0; I != Args.size(); ++I) - if (Target.Fn->getFunctionType()->getParamType(I + 1) != - Args[I]->getType()) - return false; Evaluator Eval(M.getDataLayout(), nullptr); SmallVector<Constant *, 2> EvalArgs; EvalArgs.push_back( Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0))); - EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end()); + for (unsigned I = 0; I != Args.size(); ++I) { + auto *ArgTy = dyn_cast<IntegerType>( + Target.Fn->getFunctionType()->getParamType(I + 1)); + if (!ArgTy) + return false; + EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I])); + } + Constant *RetVal; if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) || !isa<ConstantInt>(RetVal)) @@ -495,9 +765,18 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs( return true; } +void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, + uint64_t TheRetVal) { + for (auto Call : CSInfo.CallSites) + Call.replaceAndErase( + "uniform-ret-val", FnName, RemarksEnabled, + ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal)); + CSInfo.markDevirt(); +} + bool DevirtModule::tryUniformRetValOpt( - IntegerType *RetType, MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites) { + MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo, + WholeProgramDevirtResolution::ByArg *Res) { // Uniform return value optimization. If all functions return the same // constant, replace all calls with that constant. 
uint64_t TheRetVal = TargetsForSlot[0].RetVal; @@ -505,19 +784,77 @@ bool DevirtModule::tryUniformRetValOpt( if (Target.RetVal != TheRetVal) return false; - auto TheRetValConst = ConstantInt::get(RetType, TheRetVal); - for (auto Call : CallSites) - Call.replaceAndErase("uniform-ret-val", TargetsForSlot[0].Fn->getName(), - RemarksEnabled, TheRetValConst); + if (CSInfo.isExported()) { + Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal; + Res->Info = TheRetVal; + } + + applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal); if (RemarksEnabled) for (auto &&Target : TargetsForSlot) Target.WasDevirt = true; return true; } +std::string DevirtModule::getGlobalName(VTableSlot Slot, + ArrayRef<uint64_t> Args, + StringRef Name) { + std::string FullName = "__typeid_"; + raw_string_ostream OS(FullName); + OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset; + for (uint64_t Arg : Args) + OS << '_' << Arg; + OS << '_' << Name; + return OS.str(); +} + +void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name, Constant *C) { + GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage, + getGlobalName(Slot, Args, Name), C, &M); + GA->setVisibility(GlobalValue::HiddenVisibility); +} + +Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name, unsigned AbsWidth) { + Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty); + auto *GV = dyn_cast<GlobalVariable>(C); + // We only need to set metadata if the global is newly created, in which + // case it would not have hidden visibility. 
+ if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility) + return C; + + GV->setVisibility(GlobalValue::HiddenVisibility); + auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { + auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); + auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); + GV->setMetadata(LLVMContext::MD_absolute_symbol, + MDNode::get(M.getContext(), {MinC, MaxC})); + }; + if (AbsWidth == IntPtrTy->getBitWidth()) + SetAbsRange(~0ull, ~0ull); // Full set. + else if (AbsWidth) + SetAbsRange(0, 1ull << AbsWidth); + return GV; +} + +void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, + bool IsOne, + Constant *UniqueMemberAddr) { + for (auto &&Call : CSInfo.CallSites) { + IRBuilder<> B(Call.CS.getInstruction()); + Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + Call.VTable, UniqueMemberAddr); + Cmp = B.CreateZExt(Cmp, Call.CS->getType()); + Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp); + } + CSInfo.markDevirt(); +} + bool DevirtModule::tryUniqueRetValOpt( unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, - MutableArrayRef<VirtualCallSite> CallSites) { + CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res, + VTableSlot Slot, ArrayRef<uint64_t> Args) { // IsOne controls whether we look for a 0 or a 1. auto tryUniqueRetValOptFor = [&](bool IsOne) { const TypeMemberInfo *UniqueMember = nullptr; @@ -533,16 +870,23 @@ bool DevirtModule::tryUniqueRetValOpt( // checked for a uniform return value in tryUniformRetValOpt. assert(UniqueMember); - // Replace each call with the comparison. - for (auto &&Call : CallSites) { - IRBuilder<> B(Call.CS.getInstruction()); - Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy); - OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset); - Value *Cmp = B.CreateICmp(IsOne ? 
ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, - Call.VTable, OneAddr); - Call.replaceAndErase("unique-ret-val", TargetsForSlot[0].Fn->getName(), - RemarksEnabled, Cmp); + Constant *UniqueMemberAddr = + ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy); + UniqueMemberAddr = ConstantExpr::getGetElementPtr( + Int8Ty, UniqueMemberAddr, + ConstantInt::get(Int64Ty, UniqueMember->Offset)); + + if (CSInfo.isExported()) { + Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal; + Res->Info = IsOne; + + exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr); } + + // Replace each call with the comparison. + applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne, + UniqueMemberAddr); + // Update devirtualization statistics for targets. if (RemarksEnabled) for (auto &&Target : TargetsForSlot) @@ -560,9 +904,30 @@ bool DevirtModule::tryUniqueRetValOpt( return false; } +void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, + Constant *Byte, Constant *Bit) { + for (auto Call : CSInfo.CallSites) { + auto *RetType = cast<IntegerType>(Call.CS.getType()); + IRBuilder<> B(Call.CS.getInstruction()); + Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte); + if (RetType->getBitWidth() == 1) { + Value *Bits = B.CreateLoad(Addr); + Value *BitsAndBit = B.CreateAnd(Bits, Bit); + auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); + Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, + IsBitSet); + } else { + Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); + Value *Val = B.CreateLoad(RetType, ValAddr); + Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val); + } + } + CSInfo.markDevirt(); +} + bool DevirtModule::tryVirtualConstProp( - MutableArrayRef<VirtualCallTarget> TargetsForSlot, - ArrayRef<VirtualCallSite> CallSites) { + MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo, + WholeProgramDevirtResolution *Res, VTableSlot Slot) { 
// This only works if the function returns an integer. auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType()); if (!RetType) @@ -571,55 +936,38 @@ bool DevirtModule::tryVirtualConstProp( if (BitWidth > 64) return false; - // Make sure that each function does not access memory, takes at least one - // argument, does not use its first argument (which we assume is 'this'), - // and has the same return type. + // Make sure that each function is defined, does not access memory, takes at + // least one argument, does not use its first argument (which we assume is + // 'this'), and has the same return type. + // + // Note that we test whether this copy of the function is readnone, rather + // than testing function attributes, which must hold for any copy of the + // function, even a less optimized version substituted at link time. This is + // sound because the virtual constant propagation optimizations effectively + // inline all implementations of the virtual function into each call site, + // rather than using function attributes to perform local optimization. for (VirtualCallTarget &Target : TargetsForSlot) { - if (!Target.Fn->doesNotAccessMemory() || Target.Fn->arg_empty() || - !Target.Fn->arg_begin()->use_empty() || + if (Target.Fn->isDeclaration() || + computeFunctionBodyMemoryAccess(*Target.Fn, AARGetter(*Target.Fn)) != + MAK_ReadNone || + Target.Fn->arg_empty() || !Target.Fn->arg_begin()->use_empty() || Target.Fn->getReturnType() != RetType) return false; } - // Group call sites by the list of constant arguments they pass. - // The comparator ensures deterministic ordering. 
- struct ByAPIntValue { - bool operator()(const std::vector<ConstantInt *> &A, - const std::vector<ConstantInt *> &B) const { - return std::lexicographical_compare( - A.begin(), A.end(), B.begin(), B.end(), - [](ConstantInt *AI, ConstantInt *BI) { - return AI->getValue().ult(BI->getValue()); - }); - } - }; - std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>, - ByAPIntValue> - VCallSitesByConstantArg; - for (auto &&VCallSite : CallSites) { - std::vector<ConstantInt *> Args; - if (VCallSite.CS.getType() != RetType) - continue; - for (auto &&Arg : - make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) { - if (!isa<ConstantInt>(Arg)) - break; - Args.push_back(cast<ConstantInt>(&Arg)); - } - if (Args.size() + 1 != VCallSite.CS.arg_size()) - continue; - - VCallSitesByConstantArg[Args].push_back(VCallSite); - } - - for (auto &&CSByConstantArg : VCallSitesByConstantArg) { + for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) { if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first)) continue; - if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second)) + WholeProgramDevirtResolution::ByArg *ResByArg = nullptr; + if (Res) + ResByArg = &Res->ResByArg[CSByConstantArg.first]; + + if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg)) continue; - if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second)) + if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second, + ResByArg, Slot, CSByConstantArg.first)) continue; // Find an allocation offset in bits in all vtables associated with the @@ -659,26 +1007,20 @@ bool DevirtModule::tryVirtualConstProp( for (auto &&Target : TargetsForSlot) Target.WasDevirt = true; - // Rewrite each call to a load from OffsetByte/OffsetBit. 
- for (auto Call : CSByConstantArg.second) { - IRBuilder<> B(Call.CS.getInstruction()); - Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte); - if (BitWidth == 1) { - Value *Bits = B.CreateLoad(Addr); - Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); - Value *BitsAndBit = B.CreateAnd(Bits, Bit); - auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); - Call.replaceAndErase("virtual-const-prop-1-bit", - TargetsForSlot[0].Fn->getName(), - RemarksEnabled, IsBitSet); - } else { - Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); - Value *Val = B.CreateLoad(RetType, ValAddr); - Call.replaceAndErase("virtual-const-prop", - TargetsForSlot[0].Fn->getName(), - RemarksEnabled, Val); - } + Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte); + Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); + + if (CSByConstantArg.second.isExported()) { + ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp; + exportGlobal(Slot, CSByConstantArg.first, "byte", + ConstantExpr::getIntToPtr(ByteConst, Int8PtrTy)); + exportGlobal(Slot, CSByConstantArg.first, "bit", + ConstantExpr::getIntToPtr(BitConst, Int8PtrTy)); } + + // Rewrite each call to a load from OffsetByte/OffsetBit. 
+ applyVirtualConstProp(CSByConstantArg.second, + TargetsForSlot[0].Fn->getName(), ByteConst, BitConst); } return true; } @@ -733,7 +1075,11 @@ bool DevirtModule::areRemarksEnabled() { if (FL.empty()) return false; const Function &Fn = FL.front(); - auto DI = OptimizationRemark(DEBUG_TYPE, Fn, DebugLoc(), ""); + + const auto &BBL = Fn.getBasicBlockList(); + if (BBL.empty()) + return false; + auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front()); return DI.isEnabled(); } @@ -766,8 +1112,8 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); if (SeenPtrs.insert(Ptr).second) { for (DevirtCallSite Call : DevirtCalls) { - CallSlots[{TypeId, Call.Offset}].push_back( - {CI->getArgOperand(0), Call.CS, nullptr}); + CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0), + Call.CS, nullptr); } } } @@ -853,14 +1199,79 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { if (HasNonCallUses) ++NumUnsafeUses; for (DevirtCallSite Call : DevirtCalls) { - CallSlots[{TypeId, Call.Offset}].push_back( - {Ptr, Call.CS, &NumUnsafeUses}); + CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, + &NumUnsafeUses); } CI->eraseFromParent(); } } +void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { + const TypeIdSummary *TidSummary = + ImportSummary->getTypeIdSummary(cast<MDString>(Slot.TypeID)->getString()); + if (!TidSummary) + return; + auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset); + if (ResI == TidSummary->WPDRes.end()) + return; + const WholeProgramDevirtResolution &Res = ResI->second; + + if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) { + // The type of the function in the declaration is irrelevant because every + // call site will cast it to the correct type. 
+ auto *SingleImpl = M.getOrInsertFunction( + Res.SingleImplName, Type::getVoidTy(M.getContext())); + + // This is the import phase so we should not be exporting anything. + bool IsExported = false; + applySingleImplDevirt(SlotInfo, SingleImpl, IsExported); + assert(!IsExported); + } + + for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) { + auto I = Res.ResByArg.find(CSByConstantArg.first); + if (I == Res.ResByArg.end()) + continue; + auto &ResByArg = I->second; + // FIXME: We should figure out what to do about the "function name" argument + // to the apply* functions, as the function names are unavailable during the + // importing phase. For now we just pass the empty string. This does not + // impact correctness because the function names are just used for remarks. + switch (ResByArg.TheKind) { + case WholeProgramDevirtResolution::ByArg::UniformRetVal: + applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info); + break; + case WholeProgramDevirtResolution::ByArg::UniqueRetVal: { + Constant *UniqueMemberAddr = + importGlobal(Slot, CSByConstantArg.first, "unique_member"); + applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info, + UniqueMemberAddr); + break; + } + case WholeProgramDevirtResolution::ByArg::VirtualConstProp: { + Constant *Byte = importGlobal(Slot, CSByConstantArg.first, "byte", 32); + Byte = ConstantExpr::getPtrToInt(Byte, Int32Ty); + Constant *Bit = importGlobal(Slot, CSByConstantArg.first, "bit", 8); + Bit = ConstantExpr::getPtrToInt(Bit, Int8Ty); + applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit); + } + default: + break; + } + } +} + +void DevirtModule::removeRedundantTypeTests() { + auto True = ConstantInt::getTrue(M.getContext()); + for (auto &&U : NumUnsafeUsesForTypeTest) { + if (U.second == 0) { + U.first->replaceAllUsesWith(True); + U.first->eraseFromParent(); + } + } +} + bool DevirtModule::run() { Function *TypeTestFunc = M.getFunction(Intrinsic::getName(Intrinsic::type_test)); @@ -868,7 +1279,11 @@ bool 
DevirtModule::run() { M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); - if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || + // Normally if there are no users of the devirtualization intrinsics in the + // module, this pass has nothing to do. But if we are exporting, we also need + // to handle any users that appear only in the function summaries. + if (!ExportSummary && + (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc || AssumeFunc->use_empty()) && (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) return false; @@ -879,6 +1294,17 @@ bool DevirtModule::run() { if (TypeCheckedLoadFunc) scanTypeCheckedLoadUsers(TypeCheckedLoadFunc); + if (ImportSummary) { + for (auto &S : CallSlots) + importResolution(S.first, S.second); + + removeRedundantTypeTests(); + + // The rest of the code is only necessary when exporting or during regular + // LTO, so we are done. + return true; + } + // Rebuild type metadata into a map for easy lookup. std::vector<VTableBits> Bits; DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap; @@ -886,6 +1312,53 @@ bool DevirtModule::run() { if (TypeIdMap.empty()) return true; + // Collect information from summary about which calls to try to devirtualize. + if (ExportSummary) { + DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID; + for (auto &P : TypeIdMap) { + if (auto *TypeId = dyn_cast<MDString>(P.first)) + MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back( + TypeId); + } + + for (auto &P : *ExportSummary) { + for (auto &S : P.second.SummaryList) { + auto *FS = dyn_cast<FunctionSummary>(S.get()); + if (!FS) + continue; + // FIXME: Only add live functions. 
+ for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) { + for (Metadata *MD : MetadataByGUID[VF.GUID]) { + CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers = + true; + } + } + for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) { + for (Metadata *MD : MetadataByGUID[VF.GUID]) { + CallSlots[{MD, VF.Offset}] + .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS); + } + } + for (const FunctionSummary::ConstVCall &VC : + FS->type_test_assume_const_vcalls()) { + for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { + CallSlots[{MD, VC.VFunc.Offset}] + .ConstCSInfo[VC.Args] + .SummaryHasTypeTestAssumeUsers = true; + } + } + for (const FunctionSummary::ConstVCall &VC : + FS->type_checked_load_const_vcalls()) { + for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) { + CallSlots[{MD, VC.VFunc.Offset}] + .ConstCSInfo[VC.Args] + .SummaryTypeCheckedLoadUsers.push_back(FS); + } + } + } + } + } + // For each (type, offset) pair: bool DidVirtualConstProp = false; std::map<std::string, Function*> DevirtTargets; @@ -894,19 +1367,39 @@ bool DevirtModule::run() { // function implementation at offset S.first.ByteOffset, and add to // TargetsForSlot. 
std::vector<VirtualCallTarget> TargetsForSlot; - if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], - S.first.ByteOffset)) - continue; - - if (!trySingleImplDevirt(TargetsForSlot, S.second) && - tryVirtualConstProp(TargetsForSlot, S.second)) + if (tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID], + S.first.ByteOffset)) { + WholeProgramDevirtResolution *Res = nullptr; + if (ExportSummary && isa<MDString>(S.first.TypeID)) + Res = &ExportSummary + ->getOrInsertTypeIdSummary( + cast<MDString>(S.first.TypeID)->getString()) + .WPDRes[S.first.ByteOffset]; + + if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) && + tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first)) DidVirtualConstProp = true; - // Collect functions devirtualized at least for one call site for stats. - if (RemarksEnabled) - for (const auto &T : TargetsForSlot) - if (T.WasDevirt) - DevirtTargets[T.Fn->getName()] = T.Fn; + // Collect functions devirtualized at least for one call site for stats. + if (RemarksEnabled) + for (const auto &T : TargetsForSlot) + if (T.WasDevirt) + DevirtTargets[T.Fn->getName()] = T.Fn; + } + + // CFI-specific: if we are exporting and any llvm.type.checked.load + // intrinsics were *not* devirtualized, we need to add the resulting + // llvm.type.test intrinsics to the function summaries so that the + // LowerTypeTests pass will export them. + if (ExportSummary && isa<MDString>(S.first.TypeID)) { + auto GUID = + GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString()); + for (auto FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers) + FS->addTypeTest(GUID); + for (auto &CCS : S.second.ConstCSInfo) + for (auto FS : CCS.second.SummaryTypeCheckedLoadUsers) + FS->addTypeTest(GUID); + } } if (RemarksEnabled) { @@ -914,23 +1407,12 @@ bool DevirtModule::run() { for (const auto &DT : DevirtTargets) { Function *F = DT.second; DISubprogram *SP = F->getSubprogram(); - DebugLoc DL = SP ? 
DebugLoc::get(SP->getScopeLine(), 0, SP) : DebugLoc(); - emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, DL, + emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP, Twine("devirtualized ") + F->getName()); } } - // If we were able to eliminate all unsafe uses for a type checked load, - // eliminate the type test by replacing it with true. - if (TypeCheckedLoadFunc) { - auto True = ConstantInt::getTrue(M.getContext()); - for (auto &&U : NumUnsafeUsesForTypeTest) { - if (U.second == 0) { - U.first->replaceAllUsesWith(True); - U.first->eraseFromParent(); - } - } - } + removeRedundantTypeTests(); // Rebuild each global we touched as part of virtual constant propagation to // include the before and after bytes. |