Diffstat (limited to 'contrib/llvm/lib/Transforms')
161 files changed, 27740 insertions, 11564 deletions
diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp new file mode 100644 index 0000000..a97db6f --- /dev/null +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroCleanup.cpp @@ -0,0 +1,134 @@ +//===- CoroCleanup.cpp - Coroutine Cleanup Pass ---------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// This pass lowers all remaining coroutine intrinsics. +//===----------------------------------------------------------------------===// + +#include "CoroInternal.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" + +using namespace llvm; + +#define DEBUG_TYPE "coro-cleanup" + +namespace { +// Created on demand if CoroCleanup pass has work to do. +struct Lowerer : coro::LowererBase { + IRBuilder<> Builder; + Lowerer(Module &M) : LowererBase(M), Builder(Context) {} + bool lowerRemainingCoroIntrinsics(Function &F); +}; +} + +static void simplifyCFG(Function &F) { + llvm::legacy::FunctionPassManager FPM(F.getParent()); + FPM.add(createCFGSimplificationPass()); + + FPM.doInitialization(); + FPM.run(F); + FPM.doFinalization(); +} + +static void lowerSubFn(IRBuilder<> &Builder, CoroSubFnInst *SubFn) { + Builder.SetInsertPoint(SubFn); + Value *FrameRaw = SubFn->getFrame(); + int Index = SubFn->getIndex(); + + auto *FrameTy = StructType::get( + SubFn->getContext(), {Builder.getInt8PtrTy(), Builder.getInt8PtrTy()}); + PointerType *FramePtrTy = FrameTy->getPointerTo(); + + Builder.SetInsertPoint(SubFn); + auto *FramePtr = Builder.CreateBitCast(FrameRaw, FramePtrTy); + auto *Gep = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, Index); + auto *Load = Builder.CreateLoad(Gep); + + SubFn->replaceAllUsesWith(Load); +} + +bool Lowerer::lowerRemainingCoroIntrinsics(Function &F) { + bool Changed = false; + + for (auto IB = inst_begin(F), E = inst_end(F); IB != E;) { + Instruction &I = *IB++; + if (auto *II = dyn_cast<IntrinsicInst>(&I)) { + switch (II->getIntrinsicID()) { + default: + continue; + case Intrinsic::coro_begin: + II->replaceAllUsesWith(II->getArgOperand(1)); + break; + case Intrinsic::coro_free: + II->replaceAllUsesWith(II->getArgOperand(1)); + break; + case Intrinsic::coro_alloc: + II->replaceAllUsesWith(ConstantInt::getTrue(Context)); + break; + case Intrinsic::coro_id: + II->replaceAllUsesWith(ConstantTokenNone::get(Context)); + break; + case Intrinsic::coro_subfn_addr: + lowerSubFn(Builder, cast<CoroSubFnInst>(II)); + break; + } + II->eraseFromParent(); + Changed = true; + } + } + + if (Changed) { + // After replacement were made we can cleanup the function body a little. + simplifyCFG(F); + } + return Changed; +} + +//===----------------------------------------------------------------------===// +// Top Level Driver +//===----------------------------------------------------------------------===// + +namespace { + +struct CoroCleanup : FunctionPass { + static char ID; // Pass identification, replacement for typeid + + CoroCleanup() : FunctionPass(ID) {} + + std::unique_ptr<Lowerer> L; + + // This pass has work to do only if we find intrinsics we are going to lower + // in the module. 
+ bool doInitialization(Module &M) override { + if (coro::declaresIntrinsics(M, {"llvm.coro.alloc", "llvm.coro.begin", + "llvm.coro.subfn.addr", "llvm.coro.free", + "llvm.coro.id"})) + L = llvm::make_unique<Lowerer>(M); + return false; + } + + bool runOnFunction(Function &F) override { + if (L) + return L->lowerRemainingCoroIntrinsics(F); + return false; + } + void getAnalysisUsage(AnalysisUsage &AU) const override { + if (!L) + AU.setPreservesAll(); + } +}; +} + +char CoroCleanup::ID = 0; +INITIALIZE_PASS(CoroCleanup, "coro-cleanup", + "Lower all coroutine related intrinsics", false, false) + +Pass *llvm::createCoroCleanupPass() { return new CoroCleanup(); } diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroEarly.cpp new file mode 100644 index 0000000..e8bb0ca --- /dev/null +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -0,0 +1,218 @@ +//===- CoroEarly.cpp - Coroutine Early Function Pass ----------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// This pass lowers coroutine intrinsics that hide the details of the exact +// calling convention for coroutine resume and destroy functions and details of +// the structure of the coroutine frame. +//===----------------------------------------------------------------------===// + +#include "CoroInternal.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" + +using namespace llvm; + +#define DEBUG_TYPE "coro-early" + +namespace { +// Created on demand if CoroEarly pass has work to do. +class Lowerer : public coro::LowererBase { + IRBuilder<> Builder; + PointerType *const AnyResumeFnPtrTy; + + void lowerResumeOrDestroy(CallSite CS, CoroSubFnInst::ResumeKind); + void lowerCoroPromise(CoroPromiseInst *Intrin); + void lowerCoroDone(IntrinsicInst *II); + +public: + Lowerer(Module &M) + : LowererBase(M), Builder(Context), + AnyResumeFnPtrTy(FunctionType::get(Type::getVoidTy(Context), Int8Ptr, + /*isVarArg=*/false) + ->getPointerTo()) {} + bool lowerEarlyIntrinsics(Function &F); +}; +} + +// Replace a direct call to coro.resume or coro.destroy with an indirect call to +// an address returned by coro.subfn.addr intrinsic. This is done so that +// CGPassManager recognizes devirtualization when CoroElide pass replaces a call +// to coro.subfn.addr with an appropriate function address. +void Lowerer::lowerResumeOrDestroy(CallSite CS, + CoroSubFnInst::ResumeKind Index) { + Value *ResumeAddr = + makeSubFnCall(CS.getArgOperand(0), Index, CS.getInstruction()); + CS.setCalledFunction(ResumeAddr); + CS.setCallingConv(CallingConv::Fast); +} + +// Coroutine promise field is always at the fixed offset from the beginning of +// the coroutine frame. i8* coro.promise(i8*, i1 from) intrinsic adds an offset +// to a passed pointer to move from coroutine frame to coroutine promise and +// vice versa. Since we don't know exactly which coroutine frame it is, we build +// a coroutine frame mock up starting with two function pointers, followed by a +// properly aligned coroutine promise field. +// TODO: Handle the case when coroutine promise alloca has align override. 
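+// For example, on a target with 64-bit pointers and an 8-byte aligned promise,
+// the mock frame is { void (i8*)*, void (i8*)*, i8 }, the i8 field sits at
+// offset 16, and alignTo(16, 8) == 16, so the intrinsic lowers to an inbounds
+// i8 GEP of +16 (negated when the 'from' operand is true and we are converting
+// a promise pointer back into a frame pointer).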
+void Lowerer::lowerCoroPromise(CoroPromiseInst *Intrin) { + Value *Operand = Intrin->getArgOperand(0); + unsigned Alignement = Intrin->getAlignment(); + Type *Int8Ty = Builder.getInt8Ty(); + + auto *SampleStruct = + StructType::get(Context, {AnyResumeFnPtrTy, AnyResumeFnPtrTy, Int8Ty}); + const DataLayout &DL = TheModule.getDataLayout(); + int64_t Offset = alignTo( + DL.getStructLayout(SampleStruct)->getElementOffset(2), Alignement); + if (Intrin->isFromPromise()) + Offset = -Offset; + + Builder.SetInsertPoint(Intrin); + Value *Replacement = + Builder.CreateConstInBoundsGEP1_32(Int8Ty, Operand, Offset); + + Intrin->replaceAllUsesWith(Replacement); + Intrin->eraseFromParent(); +} + +// When a coroutine reaches final suspend point, it zeros out ResumeFnAddr in +// the coroutine frame (it is UB to resume from a final suspend point). +// The llvm.coro.done intrinsic is used to check whether a coroutine is +// suspended at the final suspend point or not. +void Lowerer::lowerCoroDone(IntrinsicInst *II) { + Value *Operand = II->getArgOperand(0); + + // ResumeFnAddr is the first pointer sized element of the coroutine frame. + auto *FrameTy = Int8Ptr; + PointerType *FramePtrTy = FrameTy->getPointerTo(); + + Builder.SetInsertPoint(II); + auto *BCI = Builder.CreateBitCast(Operand, FramePtrTy); + auto *Gep = Builder.CreateConstInBoundsGEP1_32(FrameTy, BCI, 0); + auto *Load = Builder.CreateLoad(Gep); + auto *Cond = Builder.CreateICmpEQ(Load, NullPtr); + + II->replaceAllUsesWith(Cond); + II->eraseFromParent(); +} + +// Prior to CoroSplit, calls to coro.begin needs to be marked as NoDuplicate, +// as CoroSplit assumes there is exactly one coro.begin. After CoroSplit, +// NoDuplicate attribute will be removed from coro.begin otherwise, it will +// interfere with inlining. +static void setCannotDuplicate(CoroIdInst *CoroId) { + for (User *U : CoroId->users()) + if (auto *CB = dyn_cast<CoroBeginInst>(U)) + CB->setCannotDuplicate(); +} + +bool Lowerer::lowerEarlyIntrinsics(Function &F) { + bool Changed = false; + CoroIdInst *CoroId = nullptr; + SmallVector<CoroFreeInst *, 4> CoroFrees; + for (auto IB = inst_begin(F), IE = inst_end(F); IB != IE;) { + Instruction &I = *IB++; + if (auto CS = CallSite(&I)) { + switch (CS.getIntrinsicID()) { + default: + continue; + case Intrinsic::coro_free: + CoroFrees.push_back(cast<CoroFreeInst>(&I)); + break; + case Intrinsic::coro_suspend: + // Make sure that final suspend point is not duplicated as CoroSplit + // pass expects that there is at most one final suspend point. + if (cast<CoroSuspendInst>(&I)->isFinal()) + CS.setCannotDuplicate(); + break; + case Intrinsic::coro_end: + // Make sure that fallthrough coro.end is not duplicated as CoroSplit + // pass expects that there is at most one fallthrough coro.end. + if (cast<CoroEndInst>(&I)->isFallthrough()) + CS.setCannotDuplicate(); + break; + case Intrinsic::coro_id: + // Mark a function that comes out of the frontend that has a coro.id + // with a coroutine attribute. 
+ if (auto *CII = cast<CoroIdInst>(&I)) { + if (CII->getInfo().isPreSplit()) { + F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT); + setCannotDuplicate(CII); + CII->setCoroutineSelf(); + CoroId = cast<CoroIdInst>(&I); + } + } + break; + case Intrinsic::coro_resume: + lowerResumeOrDestroy(CS, CoroSubFnInst::ResumeIndex); + break; + case Intrinsic::coro_destroy: + lowerResumeOrDestroy(CS, CoroSubFnInst::DestroyIndex); + break; + case Intrinsic::coro_promise: + lowerCoroPromise(cast<CoroPromiseInst>(&I)); + break; + case Intrinsic::coro_done: + lowerCoroDone(cast<IntrinsicInst>(&I)); + break; + } + Changed = true; + } + } + // Make sure that all CoroFree reference the coro.id intrinsic. + // Token type is not exposed through coroutine C/C++ builtins to plain C, so + // we allow specifying none and fixing it up here. + if (CoroId) + for (CoroFreeInst *CF : CoroFrees) + CF->setArgOperand(0, CoroId); + return Changed; +} + +//===----------------------------------------------------------------------===// +// Top Level Driver +//===----------------------------------------------------------------------===// + +namespace { + +struct CoroEarly : public FunctionPass { + static char ID; // Pass identification, replacement for typeid. + CoroEarly() : FunctionPass(ID) {} + + std::unique_ptr<Lowerer> L; + + // This pass has work to do only if we find intrinsics we are going to lower + // in the module. + bool doInitialization(Module &M) override { + if (coro::declaresIntrinsics(M, {"llvm.coro.id", "llvm.coro.destroy", + "llvm.coro.done", "llvm.coro.end", + "llvm.coro.free", "llvm.coro.promise", + "llvm.coro.resume", "llvm.coro.suspend"})) + L = llvm::make_unique<Lowerer>(M); + return false; + } + + bool runOnFunction(Function &F) override { + if (!L) + return false; + + return L->lowerEarlyIntrinsics(F); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + } +}; +} + +char CoroEarly::ID = 0; +INITIALIZE_PASS(CoroEarly, "coro-early", "Lower early coroutine intrinsics", + false, false) + +Pass *llvm::createCoroEarlyPass() { return new CoroEarly(); } diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroElide.cpp new file mode 100644 index 0000000..99974d8 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -0,0 +1,317 @@ +//===- CoroElide.cpp - Coroutine Frame Allocation Elision Pass ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// This pass replaces dynamic allocation of coroutine frame with alloca and +// replaces calls to llvm.coro.resume and llvm.coro.destroy with direct calls +// to coroutine sub-functions. +//===----------------------------------------------------------------------===// + +#include "CoroInternal.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/Pass.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace llvm; + +#define DEBUG_TYPE "coro-elide" + +namespace { +// Created on demand if CoroElide pass has work to do. 
+struct Lowerer : coro::LowererBase { + SmallVector<CoroIdInst *, 4> CoroIds; + SmallVector<CoroBeginInst *, 1> CoroBegins; + SmallVector<CoroAllocInst *, 1> CoroAllocs; + SmallVector<CoroSubFnInst *, 4> ResumeAddr; + SmallVector<CoroSubFnInst *, 4> DestroyAddr; + SmallVector<CoroFreeInst *, 1> CoroFrees; + + Lowerer(Module &M) : LowererBase(M) {} + + void elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA); + bool shouldElide() const; + bool processCoroId(CoroIdInst *, AAResults &AA); +}; +} // end anonymous namespace + +// Go through the list of coro.subfn.addr intrinsics and replace them with the +// provided constant. +static void replaceWithConstant(Constant *Value, + SmallVectorImpl<CoroSubFnInst *> &Users) { + if (Users.empty()) + return; + + // See if we need to bitcast the constant to match the type of the intrinsic + // being replaced. Note: All coro.subfn.addr intrinsics return the same type, + // so we only need to examine the type of the first one in the list. + Type *IntrTy = Users.front()->getType(); + Type *ValueTy = Value->getType(); + if (ValueTy != IntrTy) { + // May need to tweak the function type to match the type expected at the + // use site. + assert(ValueTy->isPointerTy() && IntrTy->isPointerTy()); + Value = ConstantExpr::getBitCast(Value, IntrTy); + } + + // Now the value type matches the type of the intrinsic. Replace them all! + for (CoroSubFnInst *I : Users) + replaceAndRecursivelySimplify(I, Value); +} + +// See if any operand of the call instruction references the coroutine frame. +static bool operandReferences(CallInst *CI, AllocaInst *Frame, AAResults &AA) { + for (Value *Op : CI->operand_values()) + if (AA.alias(Op, Frame) != NoAlias) + return true; + return false; +} + +// Look for any tail calls referencing the coroutine frame and remove tail +// attribute from them, since now coroutine frame resides on the stack and tail +// call implies that the function does not references anything on the stack. +static void removeTailCallAttribute(AllocaInst *Frame, AAResults &AA) { + Function &F = *Frame->getFunction(); + MemoryLocation Mem(Frame); + for (Instruction &I : instructions(F)) + if (auto *Call = dyn_cast<CallInst>(&I)) + if (Call->isTailCall() && operandReferences(Call, Frame, AA)) { + // FIXME: If we ever hit this check. Evaluate whether it is more + // appropriate to retain musttail and allow the code to compile. + if (Call->isMustTailCall()) + report_fatal_error("Call referring to the coroutine frame cannot be " + "marked as musttail"); + Call->setTailCall(false); + } +} + +// Given a resume function @f.resume(%f.frame* %frame), returns %f.frame type. +static Type *getFrameType(Function *Resume) { + auto *ArgType = Resume->getArgumentList().front().getType(); + return cast<PointerType>(ArgType)->getElementType(); +} + +// Finds first non alloca instruction in the entry block of a function. +static Instruction *getFirstNonAllocaInTheEntryBlock(Function *F) { + for (Instruction &I : F->getEntryBlock()) + if (!isa<AllocaInst>(&I)) + return &I; + llvm_unreachable("no terminator in the entry block"); +} + +// To elide heap allocations we need to suppress code blocks guarded by +// llvm.coro.alloc and llvm.coro.free instructions. 
+void Lowerer::elideHeapAllocations(Function *F, Type *FrameTy, AAResults &AA) { + LLVMContext &C = FrameTy->getContext(); + auto *InsertPt = + getFirstNonAllocaInTheEntryBlock(CoroIds.front()->getFunction()); + + // Replacing llvm.coro.alloc with false will suppress dynamic + // allocation as it is expected for the frontend to generate the code that + // looks like: + // id = coro.id(...) + // mem = coro.alloc(id) ? malloc(coro.size()) : 0; + // coro.begin(id, mem) + auto *False = ConstantInt::getFalse(C); + for (auto *CA : CoroAllocs) { + CA->replaceAllUsesWith(False); + CA->eraseFromParent(); + } + + // FIXME: Design how to transmit alignment information for every alloca that + // is spilled into the coroutine frame and recreate the alignment information + // here. Possibly we will need to do a mini SROA here and break the coroutine + // frame into individual AllocaInst recreating the original alignment. + auto *Frame = new AllocaInst(FrameTy, "", InsertPt); + auto *FrameVoidPtr = + new BitCastInst(Frame, Type::getInt8PtrTy(C), "vFrame", InsertPt); + + for (auto *CB : CoroBegins) { + CB->replaceAllUsesWith(FrameVoidPtr); + CB->eraseFromParent(); + } + + // Since now coroutine frame lives on the stack we need to make sure that + // any tail call referencing it, must be made non-tail call. + removeTailCallAttribute(Frame, AA); +} + +bool Lowerer::shouldElide() const { + // If no CoroAllocs, we cannot suppress allocation, so elision is not + // possible. + if (CoroAllocs.empty()) + return false; + + // Check that for every coro.begin there is a coro.destroy directly + // referencing the SSA value of that coro.begin. If the value escaped, then + // coro.destroy would have been referencing a memory location storing that + // value and not the virtual register. + + SmallPtrSet<CoroBeginInst *, 8> ReferencedCoroBegins; + + for (CoroSubFnInst *DA : DestroyAddr) { + if (auto *CB = dyn_cast<CoroBeginInst>(DA->getFrame())) + ReferencedCoroBegins.insert(CB); + else + return false; + } + + // If size of the set is the same as total number of CoroBegins, means we + // found a coro.free or coro.destroy mentioning a coro.begin and we can + // perform heap elision. + return ReferencedCoroBegins.size() == CoroBegins.size(); +} + +bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA) { + CoroBegins.clear(); + CoroAllocs.clear(); + CoroFrees.clear(); + ResumeAddr.clear(); + DestroyAddr.clear(); + + // Collect all coro.begin and coro.allocs associated with this coro.id. + for (User *U : CoroId->users()) { + if (auto *CB = dyn_cast<CoroBeginInst>(U)) + CoroBegins.push_back(CB); + else if (auto *CA = dyn_cast<CoroAllocInst>(U)) + CoroAllocs.push_back(CA); + else if (auto *CF = dyn_cast<CoroFreeInst>(U)) + CoroFrees.push_back(CF); + } + + // Collect all coro.subfn.addrs associated with coro.begin. + // Note, we only devirtualize the calls if their coro.subfn.addr refers to + // coro.begin directly. If we run into cases where this check is too + // conservative, we can consider relaxing the check. + for (CoroBeginInst *CB : CoroBegins) { + for (User *U : CB->users()) + if (auto *II = dyn_cast<CoroSubFnInst>(U)) + switch (II->getIndex()) { + case CoroSubFnInst::ResumeIndex: + ResumeAddr.push_back(II); + break; + case CoroSubFnInst::DestroyIndex: + DestroyAddr.push_back(II); + break; + default: + llvm_unreachable("unexpected coro.subfn.addr constant"); + } + } + + // PostSplit coro.id refers to an array of subfunctions in its Info + // argument. 
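+  // The array has the layout [resume, destroy, cleanup], matching the
+  // CoroSubFnInst::ResumeKind indices used below to select an element.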
+ ConstantArray *Resumers = CoroId->getInfo().Resumers; + assert(Resumers && "PostSplit coro.id Info argument must refer to an array" + "of coroutine subfunctions"); + auto *ResumeAddrConstant = + ConstantExpr::getExtractValue(Resumers, CoroSubFnInst::ResumeIndex); + + replaceWithConstant(ResumeAddrConstant, ResumeAddr); + + bool ShouldElide = shouldElide(); + + auto *DestroyAddrConstant = ConstantExpr::getExtractValue( + Resumers, + ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex); + + replaceWithConstant(DestroyAddrConstant, DestroyAddr); + + if (ShouldElide) { + auto *FrameTy = getFrameType(cast<Function>(ResumeAddrConstant)); + elideHeapAllocations(CoroId->getFunction(), FrameTy, AA); + coro::replaceCoroFree(CoroId, /*Elide=*/true); + } + + return true; +} + +// See if there are any coro.subfn.addr instructions referring to coro.devirt +// trigger, if so, replace them with a direct call to devirt trigger function. +static bool replaceDevirtTrigger(Function &F) { + SmallVector<CoroSubFnInst *, 1> DevirtAddr; + for (auto &I : instructions(F)) + if (auto *SubFn = dyn_cast<CoroSubFnInst>(&I)) + if (SubFn->getIndex() == CoroSubFnInst::RestartTrigger) + DevirtAddr.push_back(SubFn); + + if (DevirtAddr.empty()) + return false; + + Module &M = *F.getParent(); + Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN); + assert(DevirtFn && "coro.devirt.fn not found"); + replaceWithConstant(DevirtFn, DevirtAddr); + + return true; +} + +//===----------------------------------------------------------------------===// +// Top Level Driver +//===----------------------------------------------------------------------===// + +namespace { +struct CoroElide : FunctionPass { + static char ID; + CoroElide() : FunctionPass(ID) {} + + std::unique_ptr<Lowerer> L; + + bool doInitialization(Module &M) override { + if (coro::declaresIntrinsics(M, {"llvm.coro.id"})) + L = llvm::make_unique<Lowerer>(M); + return false; + } + + bool runOnFunction(Function &F) override { + if (!L) + return false; + + bool Changed = false; + + if (F.hasFnAttribute(CORO_PRESPLIT_ATTR)) + Changed = replaceDevirtTrigger(F); + + L->CoroIds.clear(); + + // Collect all PostSplit coro.ids. + for (auto &I : instructions(F)) + if (auto *CII = dyn_cast<CoroIdInst>(&I)) + if (CII->getInfo().isPostSplit()) + // If it is the coroutine itself, don't touch it. + if (CII->getCoroutine() != CII->getFunction()) + L->CoroIds.push_back(CII); + + // If we did not find any coro.id, there is nothing to do. 
+ if (L->CoroIds.empty()) + return Changed; + + AAResults &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + + for (auto *CII : L->CoroIds) + Changed |= L->processCoroId(CII, AA); + + return Changed; + } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AAResultsWrapperPass>(); + } +}; +} + +char CoroElide::ID = 0; +INITIALIZE_PASS_BEGIN( + CoroElide, "coro-elide", + "Coroutine frame allocation elision and indirect calls replacement", false, + false) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END( + CoroElide, "coro-elide", + "Coroutine frame allocation elision and indirect calls replacement", false, + false) + +Pass *llvm::createCoroElidePass() { return new CoroElide(); } diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroFrame.cpp new file mode 100644 index 0000000..bb28558a --- /dev/null +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -0,0 +1,727 @@ +//===- CoroFrame.cpp - Builds and manipulates coroutine frame -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// This file contains classes used to discover if for a particular value +// there from sue to definition that crosses a suspend block. +// +// Using the information discovered we form a Coroutine Frame structure to +// contain those values. All uses of those values are replaced with appropriate +// GEP + load from the coroutine frame. At the point of the definition we spill +// the value into the coroutine frame. +// +// TODO: pack values tightly using liveness info. +//===----------------------------------------------------------------------===// + +#include "CoroInternal.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/circular_raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" + +using namespace llvm; + +// The "coro-suspend-crossing" flag is very noisy. There is another debug type, +// "coro-frame", which results in leaner debug spew. +#define DEBUG_TYPE "coro-suspend-crossing" + +enum { SmallVectorThreshold = 32 }; + +// Provides two way mapping between the blocks and numbers. +namespace { +class BlockToIndexMapping { + SmallVector<BasicBlock *, SmallVectorThreshold> V; + +public: + size_t size() const { return V.size(); } + + BlockToIndexMapping(Function &F) { + for (BasicBlock &BB : F) + V.push_back(&BB); + std::sort(V.begin(), V.end()); + } + + size_t blockToIndex(BasicBlock *BB) const { + auto *I = std::lower_bound(V.begin(), V.end(), BB); + assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block"); + return I - V.begin(); + } + + BasicBlock *indexToBlock(unsigned Index) const { return V[Index]; } +}; +} // end anonymous namespace + +// The SuspendCrossingInfo maintains data that allows to answer a question +// whether given two BasicBlocks A and B there is a path from A to B that +// passes through a suspend point. 
+// +// For every basic block 'i' it maintains a BlockData that consists of: +// Consumes: a bit vector which contains a set of indices of blocks that can +// reach block 'i' +// Kills: a bit vector which contains a set of indices of blocks that can +// reach block 'i', but one of the path will cross a suspend point +// Suspend: a boolean indicating whether block 'i' contains a suspend point. +// End: a boolean indicating whether block 'i' contains a coro.end intrinsic. +// +namespace { +struct SuspendCrossingInfo { + BlockToIndexMapping Mapping; + + struct BlockData { + BitVector Consumes; + BitVector Kills; + bool Suspend = false; + bool End = false; + }; + SmallVector<BlockData, SmallVectorThreshold> Block; + + iterator_range<succ_iterator> successors(BlockData const &BD) const { + BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); + return llvm::successors(BB); + } + + BlockData &getBlockData(BasicBlock *BB) { + return Block[Mapping.blockToIndex(BB)]; + } + + void dump() const; + void dump(StringRef Label, BitVector const &BV) const; + + SuspendCrossingInfo(Function &F, coro::Shape &Shape); + + bool hasPathCrossingSuspendPoint(BasicBlock *DefBB, BasicBlock *UseBB) const { + size_t const DefIndex = Mapping.blockToIndex(DefBB); + size_t const UseIndex = Mapping.blockToIndex(UseBB); + + assert(Block[UseIndex].Consumes[DefIndex] && "use must consume def"); + bool const Result = Block[UseIndex].Kills[DefIndex]; + DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName() + << " answer is " << Result << "\n"); + return Result; + } + + bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const { + auto *I = cast<Instruction>(U); + + // We rewrote PHINodes, so that only the ones with exactly one incoming + // value need to be analyzed. + if (auto *PN = dyn_cast<PHINode>(I)) + if (PN->getNumIncomingValues() > 1) + return false; + + BasicBlock *UseBB = I->getParent(); + return hasPathCrossingSuspendPoint(DefBB, UseBB); + } + + bool isDefinitionAcrossSuspend(Argument &A, User *U) const { + return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U); + } + + bool isDefinitionAcrossSuspend(Instruction &I, User *U) const { + return isDefinitionAcrossSuspend(I.getParent(), U); + } +}; +} // end anonymous namespace + +LLVM_DUMP_METHOD void SuspendCrossingInfo::dump(StringRef Label, + BitVector const &BV) const { + dbgs() << Label << ":"; + for (size_t I = 0, N = BV.size(); I < N; ++I) + if (BV[I]) + dbgs() << " " << Mapping.indexToBlock(I)->getName(); + dbgs() << "\n"; +} + +LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const { + for (size_t I = 0, N = Block.size(); I < N; ++I) { + BasicBlock *const B = Mapping.indexToBlock(I); + dbgs() << B->getName() << ":\n"; + dump(" Consumes", Block[I].Consumes); + dump(" Kills", Block[I].Kills); + } + dbgs() << "\n"; +} + +SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) + : Mapping(F) { + const size_t N = Mapping.size(); + Block.resize(N); + + // Initialize every block so that it consumes itself + for (size_t I = 0; I < N; ++I) { + auto &B = Block[I]; + B.Consumes.resize(N); + B.Kills.resize(N); + B.Consumes.set(I); + } + + // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as + // the code beyond coro.end is reachable during initial invocation of the + // coroutine. + for (auto *CE : Shape.CoroEnds) + getBlockData(CE->getParent()).End = true; + + // Mark all suspend blocks and indicate that they kill everything they + // consume. 
Note, that crossing coro.save also requires a spill, as any code + // between coro.save and coro.suspend may resume the coroutine and all of the + // state needs to be saved by that time. + auto markSuspendBlock = [&](IntrinsicInst* BarrierInst) { + BasicBlock *SuspendBlock = BarrierInst->getParent(); + auto &B = getBlockData(SuspendBlock); + B.Suspend = true; + B.Kills |= B.Consumes; + }; + for (CoroSuspendInst *CSI : Shape.CoroSuspends) { + markSuspendBlock(CSI); + markSuspendBlock(CSI->getCoroSave()); + } + + // Iterate propagating consumes and kills until they stop changing. + int Iteration = 0; + (void)Iteration; + + bool Changed; + do { + DEBUG(dbgs() << "iteration " << ++Iteration); + DEBUG(dbgs() << "==============\n"); + + Changed = false; + for (size_t I = 0; I < N; ++I) { + auto &B = Block[I]; + for (BasicBlock *SI : successors(B)) { + + auto SuccNo = Mapping.blockToIndex(SI); + + // Saved Consumes and Kills bitsets so that it is easy to see + // if anything changed after propagation. + auto &S = Block[SuccNo]; + auto SavedConsumes = S.Consumes; + auto SavedKills = S.Kills; + + // Propagate Kills and Consumes from block B into its successor S. + S.Consumes |= B.Consumes; + S.Kills |= B.Kills; + + // If block B is a suspend block, it should propagate kills into the + // its successor for every block B consumes. + if (B.Suspend) { + S.Kills |= B.Consumes; + } + if (S.Suspend) { + // If block S is a suspend block, it should kill all of the blocks it + // consumes. + S.Kills |= S.Consumes; + } else if (S.End) { + // If block S is an end block, it should not propagate kills as the + // blocks following coro.end() are reached during initial invocation + // of the coroutine while all the data are still available on the + // stack or in the registers. + S.Kills.reset(); + } else { + // This is reached when S block it not Suspend nor coro.end and it + // need to make sure that it is not in the kill set. + S.Kills.reset(SuccNo); + } + + // See if anything changed. + Changed |= (S.Kills != SavedKills) || (S.Consumes != SavedConsumes); + + if (S.Kills != SavedKills) { + DEBUG(dbgs() << "\nblock " << I << " follower " << SI->getName() + << "\n"); + DEBUG(dump("S.Kills", S.Kills)); + DEBUG(dump("SavedKills", SavedKills)); + } + if (S.Consumes != SavedConsumes) { + DEBUG(dbgs() << "\nblock " << I << " follower " << SI << "\n"); + DEBUG(dump("S.Consume", S.Consumes)); + DEBUG(dump("SavedCons", SavedConsumes)); + } + } + } + } while (Changed); + DEBUG(dump()); +} + +#undef DEBUG_TYPE // "coro-suspend-crossing" +#define DEBUG_TYPE "coro-frame" + +// We build up the list of spills for every case where a use is separated +// from the definition by a suspend point. + +struct Spill : std::pair<Value *, Instruction *> { + using base = std::pair<Value *, Instruction *>; + + Spill(Value *Def, User *U) : base(Def, cast<Instruction>(U)) {} + + Value *def() const { return first; } + Instruction *user() const { return second; } + BasicBlock *userBlock() const { return second->getParent(); } + + std::pair<Value *, BasicBlock *> getKey() const { + return {def(), userBlock()}; + } + + bool operator<(Spill const &rhs) const { return getKey() < rhs.getKey(); } +}; + +// Note that there may be more than one record with the same value of Def in +// the SpillInfo vector. 
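+// For example, a value %x defined before a suspend point and used in two
+// different resume blocks produces two Spill records with the same Def (%x)
+// and different user blocks; the vector is kept sorted by (Def, userBlock).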
+using SpillInfo = SmallVector<Spill, 8>; + +#ifndef NDEBUG +static void dump(StringRef Title, SpillInfo const &Spills) { + dbgs() << "------------- " << Title << "--------------\n"; + Value *CurrentValue = nullptr; + for (auto const &E : Spills) { + if (CurrentValue != E.def()) { + CurrentValue = E.def(); + CurrentValue->dump(); + } + dbgs() << " user: "; + E.user()->dump(); + } +} +#endif + +// Build a struct that will keep state for an active coroutine. +// struct f.frame { +// ResumeFnTy ResumeFnAddr; +// ResumeFnTy DestroyFnAddr; +// int ResumeIndex; +// ... promise (if present) ... +// ... spills ... +// }; +static StructType *buildFrameType(Function &F, coro::Shape &Shape, + SpillInfo &Spills) { + LLVMContext &C = F.getContext(); + SmallString<32> Name(F.getName()); + Name.append(".Frame"); + StructType *FrameTy = StructType::create(C, Name); + auto *FramePtrTy = FrameTy->getPointerTo(); + auto *FnTy = FunctionType::get(Type::getVoidTy(C), FramePtrTy, + /*IsVarArgs=*/false); + auto *FnPtrTy = FnTy->getPointerTo(); + + // Figure out how wide should be an integer type storing the suspend index. + unsigned IndexBits = std::max(1U, Log2_64_Ceil(Shape.CoroSuspends.size())); + Type *PromiseType = Shape.PromiseAlloca + ? Shape.PromiseAlloca->getType()->getElementType() + : Type::getInt1Ty(C); + SmallVector<Type *, 8> Types{FnPtrTy, FnPtrTy, PromiseType, + Type::getIntNTy(C, IndexBits)}; + Value *CurrentDef = nullptr; + + // Create an entry for every spilled value. + for (auto const &S : Spills) { + if (CurrentDef == S.def()) + continue; + + CurrentDef = S.def(); + // PromiseAlloca was already added to Types array earlier. + if (CurrentDef == Shape.PromiseAlloca) + continue; + + Type *Ty = nullptr; + if (auto *AI = dyn_cast<AllocaInst>(CurrentDef)) + Ty = AI->getAllocatedType(); + else + Ty = CurrentDef->getType(); + + Types.push_back(Ty); + } + FrameTy->setBody(Types); + + return FrameTy; +} + +// Replace all alloca and SSA values that are accessed across suspend points +// with GetElementPointer from coroutine frame + loads and stores. Create an +// AllocaSpillBB that will become the new entry block for the resume parts of +// the coroutine: +// +// %hdl = coro.begin(...) +// whatever +// +// becomes: +// +// %hdl = coro.begin(...) +// %FramePtr = bitcast i8* hdl to %f.frame* +// br label %AllocaSpillBB +// +// AllocaSpillBB: +// ; geps corresponding to allocas that were moved to coroutine frame +// br label PostSpill +// +// PostSpill: +// whatever +// +// +static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { + auto *CB = Shape.CoroBegin; + IRBuilder<> Builder(CB->getNextNode()); + PointerType *FramePtrTy = Shape.FrameTy->getPointerTo(); + auto *FramePtr = + cast<Instruction>(Builder.CreateBitCast(CB, FramePtrTy, "FramePtr")); + Type *FrameTy = FramePtrTy->getElementType(); + + Value *CurrentValue = nullptr; + BasicBlock *CurrentBlock = nullptr; + Value *CurrentReload = nullptr; + unsigned Index = coro::Shape::LastKnownField; + + // We need to keep track of any allocas that need "spilling" + // since they will live in the coroutine frame now, all access to them + // need to be changed, not just the access across suspend points + // we remember allocas and their indices to be handled once we processed + // all the spills. 
+ SmallVector<std::pair<AllocaInst *, unsigned>, 4> Allocas; + // Promise alloca (if present) has a fixed field number (Shape::PromiseField) + if (Shape.PromiseAlloca) + Allocas.emplace_back(Shape.PromiseAlloca, coro::Shape::PromiseField); + + // Create a load instruction to reload the spilled value from the coroutine + // frame. + auto CreateReload = [&](Instruction *InsertBefore) { + Builder.SetInsertPoint(InsertBefore); + auto *G = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, Index, + CurrentValue->getName() + + Twine(".reload.addr")); + return isa<AllocaInst>(CurrentValue) + ? G + : Builder.CreateLoad(G, + CurrentValue->getName() + Twine(".reload")); + }; + + for (auto const &E : Spills) { + // If we have not seen the value, generate a spill. + if (CurrentValue != E.def()) { + CurrentValue = E.def(); + CurrentBlock = nullptr; + CurrentReload = nullptr; + + ++Index; + + if (auto *AI = dyn_cast<AllocaInst>(CurrentValue)) { + // Spilled AllocaInst will be replaced with GEP from the coroutine frame + // there is no spill required. + Allocas.emplace_back(AI, Index); + if (!AI->isStaticAlloca()) + report_fatal_error("Coroutines cannot handle non static allocas yet"); + } else { + // Otherwise, create a store instruction storing the value into the + // coroutine frame. For, argument, we will place the store instruction + // right after the coroutine frame pointer instruction, i.e. bitcase of + // coro.begin from i8* to %f.frame*. For all other values, the spill is + // placed immediately after the definition. + Builder.SetInsertPoint( + isa<Argument>(CurrentValue) + ? FramePtr->getNextNode() + : dyn_cast<Instruction>(E.def())->getNextNode()); + + auto *G = Builder.CreateConstInBoundsGEP2_32( + FrameTy, FramePtr, 0, Index, + CurrentValue->getName() + Twine(".spill.addr")); + Builder.CreateStore(CurrentValue, G); + } + } + + // If we have not seen the use block, generate a reload in it. + if (CurrentBlock != E.userBlock()) { + CurrentBlock = E.userBlock(); + CurrentReload = CreateReload(&*CurrentBlock->getFirstInsertionPt()); + } + + // If we have a single edge PHINode, remove it and replace it with a reload + // from the coroutine frame. (We already took care of multi edge PHINodes + // by rewriting them in the rewritePHIs function). + if (auto *PN = dyn_cast<PHINode>(E.user())) { + assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming " + "values in the PHINode"); + PN->replaceAllUsesWith(CurrentReload); + PN->eraseFromParent(); + continue; + } + + // Replace all uses of CurrentValue in the current instruction with reload. + E.user()->replaceUsesOfWith(CurrentValue, CurrentReload); + } + + BasicBlock *FramePtrBB = FramePtr->getParent(); + Shape.AllocaSpillBlock = + FramePtrBB->splitBasicBlock(FramePtr->getNextNode(), "AllocaSpillBB"); + Shape.AllocaSpillBlock->splitBasicBlock(&Shape.AllocaSpillBlock->front(), + "PostSpill"); + + Builder.SetInsertPoint(&Shape.AllocaSpillBlock->front()); + // If we found any allocas, replace all of their remaining uses with Geps. + for (auto &P : Allocas) { + auto *G = + Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, P.second); + // We are not using ReplaceInstWithInst(P.first, cast<Instruction>(G)) here, + // as we are changing location of the instruction. + G->takeName(P.first); + P.first->replaceAllUsesWith(G); + P.first->eraseFromParent(); + } + return FramePtr; +} + +static void rewritePHIs(BasicBlock &BB) { + // For every incoming edge we will create a block holding all + // incoming values in a single PHI nodes. 
+ // + // loop: + // %n.val = phi i32[%n, %entry], [%inc, %loop] + // + // It will create: + // + // loop.from.entry: + // %n.loop.pre = phi i32 [%n, %entry] + // br %label loop + // loop.from.loop: + // %inc.loop.pre = phi i32 [%inc, %loop] + // br %label loop + // + // After this rewrite, further analysis will ignore any phi nodes with more + // than one incoming edge. + + // TODO: Simplify PHINodes in the basic block to remove duplicate + // predecessors. + + SmallVector<BasicBlock *, 8> Preds(pred_begin(&BB), pred_end(&BB)); + for (BasicBlock *Pred : Preds) { + auto *IncomingBB = SplitEdge(Pred, &BB); + IncomingBB->setName(BB.getName() + Twine(".from.") + Pred->getName()); + auto *PN = cast<PHINode>(&BB.front()); + do { + int Index = PN->getBasicBlockIndex(IncomingBB); + Value *V = PN->getIncomingValue(Index); + PHINode *InputV = PHINode::Create( + V->getType(), 1, V->getName() + Twine(".") + BB.getName(), + &IncomingBB->front()); + InputV->addIncoming(V, Pred); + PN->setIncomingValue(Index, InputV); + PN = dyn_cast<PHINode>(PN->getNextNode()); + } while (PN); + } +} + +static void rewritePHIs(Function &F) { + SmallVector<BasicBlock *, 8> WorkList; + + for (BasicBlock &BB : F) + if (auto *PN = dyn_cast<PHINode>(&BB.front())) + if (PN->getNumIncomingValues() > 1) + WorkList.push_back(&BB); + + for (BasicBlock *BB : WorkList) + rewritePHIs(*BB); +} + +// Check for instructions that we can recreate on resume as opposed to spill +// the result into a coroutine frame. +static bool materializable(Instruction &V) { + return isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) || + isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V); +} + +// Check for structural coroutine intrinsics that should not be spilled into +// the coroutine frame. +static bool isCoroutineStructureIntrinsic(Instruction &I) { + return isa<CoroIdInst>(&I) || isa<CoroBeginInst>(&I) || + isa<CoroSaveInst>(&I) || isa<CoroSuspendInst>(&I); +} + +// For every use of the value that is across suspend point, recreate that value +// after a suspend point. +static void rewriteMaterializableInstructions(IRBuilder<> &IRB, + SpillInfo const &Spills) { + BasicBlock *CurrentBlock = nullptr; + Instruction *CurrentMaterialization = nullptr; + Instruction *CurrentDef = nullptr; + + for (auto const &E : Spills) { + // If it is a new definition, update CurrentXXX variables. + if (CurrentDef != E.def()) { + CurrentDef = cast<Instruction>(E.def()); + CurrentBlock = nullptr; + CurrentMaterialization = nullptr; + } + + // If we have not seen this block, materialize the value. + if (CurrentBlock != E.userBlock()) { + CurrentBlock = E.userBlock(); + CurrentMaterialization = cast<Instruction>(CurrentDef)->clone(); + CurrentMaterialization->setName(CurrentDef->getName()); + CurrentMaterialization->insertBefore( + &*CurrentBlock->getFirstInsertionPt()); + } + + if (auto *PN = dyn_cast<PHINode>(E.user())) { + assert(PN->getNumIncomingValues() == 1 && "unexpected number of incoming " + "values in the PHINode"); + PN->replaceAllUsesWith(CurrentMaterialization); + PN->eraseFromParent(); + continue; + } + + // Replace all uses of CurrentDef in the current instruction with the + // CurrentMaterialization for the block. + E.user()->replaceUsesOfWith(CurrentDef, CurrentMaterialization); + } +} + +// Move early uses of spilled variable after CoroBegin. +// For example, if a parameter had address taken, we may end up with the code +// like: +// define @f(i32 %n) { +// %n.addr = alloca i32 +// store %n, %n.addr +// ... 
+// call @coro.begin +// we need to move the store after coro.begin +static void moveSpillUsesAfterCoroBegin(Function &F, SpillInfo const &Spills, + CoroBeginInst *CoroBegin) { + DominatorTree DT(F); + SmallVector<Instruction *, 8> NeedsMoving; + + Value *CurrentValue = nullptr; + + for (auto const &E : Spills) { + if (CurrentValue == E.def()) + continue; + + CurrentValue = E.def(); + + for (User *U : CurrentValue->users()) { + Instruction *I = cast<Instruction>(U); + if (!DT.dominates(CoroBegin, I)) { + // TODO: Make this more robust. Currently if we run into a situation + // where simple instruction move won't work we panic and + // report_fatal_error. + for (User *UI : I->users()) { + if (!DT.dominates(CoroBegin, cast<Instruction>(UI))) + report_fatal_error("cannot move instruction since its users are not" + " dominated by CoroBegin"); + } + + DEBUG(dbgs() << "will move: " << *I << "\n"); + NeedsMoving.push_back(I); + } + } + } + + Instruction *InsertPt = CoroBegin->getNextNode(); + for (Instruction *I : NeedsMoving) + I->moveBefore(InsertPt); +} + +// Splits the block at a particular instruction unless it is the first +// instruction in the block with a single predecessor. +static BasicBlock *splitBlockIfNotFirst(Instruction *I, const Twine &Name) { + auto *BB = I->getParent(); + if (&BB->front() == I) { + if (BB->getSinglePredecessor()) { + BB->setName(Name); + return BB; + } + } + return BB->splitBasicBlock(I, Name); +} + +// Split above and below a particular instruction so that it +// will be all alone by itself in a block. +static void splitAround(Instruction *I, const Twine &Name) { + splitBlockIfNotFirst(I, Name); + splitBlockIfNotFirst(I->getNextNode(), "After" + Name); +} + +void coro::buildCoroutineFrame(Function &F, Shape &Shape) { + // Lower coro.dbg.declare to coro.dbg.value, since we are going to rewrite + // access to local variables. + LowerDbgDeclare(F); + + Shape.PromiseAlloca = Shape.CoroBegin->getId()->getPromise(); + if (Shape.PromiseAlloca) { + Shape.CoroBegin->getId()->clearPromise(); + } + + // Make sure that all coro.save, coro.suspend and the fallthrough coro.end + // intrinsics are in their own blocks to simplify the logic of building up + // SuspendCrossing data. + for (CoroSuspendInst *CSI : Shape.CoroSuspends) { + splitAround(CSI->getCoroSave(), "CoroSave"); + splitAround(CSI, "CoroSuspend"); + } + + // Put fallthrough CoroEnd into its own block. Note: Shape::buildFrom places + // the fallthrough coro.end as the first element of CoroEnds array. + splitAround(Shape.CoroEnds.front(), "CoroEnd"); + + // Transforms multi-edge PHI Nodes, so that any value feeding into a PHI will + // never has its definition separated from the PHI by the suspend point. + rewritePHIs(F); + + // Build suspend crossing info. + SuspendCrossingInfo Checker(F, Shape); + + IRBuilder<> Builder(F.getContext()); + SpillInfo Spills; + + // See if there are materializable instructions across suspend points. + for (Instruction &I : instructions(F)) + if (materializable(I)) + for (User *U : I.users()) + if (Checker.isDefinitionAcrossSuspend(I, U)) + Spills.emplace_back(&I, U); + + // Rewrite materializable instructions to be materialized at the use point. + std::sort(Spills.begin(), Spills.end()); + DEBUG(dump("Materializations", Spills)); + rewriteMaterializableInstructions(Builder, Spills); + + // Collect the spills for arguments and other not-materializable values. 
+ Spills.clear(); + for (Argument &A : F.getArgumentList()) + for (User *U : A.users()) + if (Checker.isDefinitionAcrossSuspend(A, U)) + Spills.emplace_back(&A, U); + + for (Instruction &I : instructions(F)) { + // Values returned from coroutine structure intrinsics should not be part + // of the Coroutine Frame. + if (isCoroutineStructureIntrinsic(I)) + continue; + // The Coroutine Promise always included into coroutine frame, no need to + // check for suspend crossing. + if (Shape.PromiseAlloca == &I) + continue; + + for (User *U : I.users()) + if (Checker.isDefinitionAcrossSuspend(I, U)) { + // We cannot spill a token. + if (I.getType()->isTokenTy()) + report_fatal_error( + "token definition is separated from the use by a suspend point"); + assert(!materializable(I) && + "rewriteMaterializable did not do its job"); + Spills.emplace_back(&I, U); + } + } + std::sort(Spills.begin(), Spills.end()); + DEBUG(dump("Spills", Spills)); + moveSpillUsesAfterCoroBegin(F, Spills, Shape.CoroBegin); + Shape.FrameTy = buildFrameType(F, Shape, Spills); + Shape.FramePtr = insertSpills(Spills, Shape); +} diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroInstr.h b/contrib/llvm/lib/Transforms/Coroutines/CoroInstr.h new file mode 100644 index 0000000..e03cef4 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroInstr.h @@ -0,0 +1,318 @@ +//===-- CoroInstr.h - Coroutine Intrinsics Instruction Wrappers -*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// This file defines classes that make it really easy to deal with intrinsic +// functions with the isa/dyncast family of functions. In particular, this +// allows you to do things like: +// +// if (auto *SF = dyn_cast<CoroSubFnInst>(Inst)) +// ... SF->getFrame() ... +// +// All intrinsic function calls are instances of the call instruction, so these +// are all subclasses of the CallInst class. Note that none of these classes +// has state or virtual methods, which is an important part of this gross/neat +// hack working. +// +// The helpful comment above is borrowed from llvm/IntrinsicInst.h, we keep +// coroutine intrinsic wrappers here since they are only used by the passes in +// the Coroutine library. +//===----------------------------------------------------------------------===// + +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IntrinsicInst.h" + +namespace llvm { + +/// This class represents the llvm.coro.subfn.addr instruction. 
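+/// It takes a coroutine frame pointer and an index selecting which entry to
+/// return: resume (0), destroy (1) or cleanup (2); the ResumeKind enum below
+/// mirrors these values, with RestartTrigger (-1) reserved for the devirt
+/// trigger used by the CoroElide pass.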
+class LLVM_LIBRARY_VISIBILITY CoroSubFnInst : public IntrinsicInst { + enum { FrameArg, IndexArg }; + +public: + enum ResumeKind { + RestartTrigger = -1, + ResumeIndex, + DestroyIndex, + CleanupIndex, + IndexLast, + IndexFirst = RestartTrigger + }; + + Value *getFrame() const { return getArgOperand(FrameArg); } + ResumeKind getIndex() const { + int64_t Index = getRawIndex()->getValue().getSExtValue(); + assert(Index >= IndexFirst && Index < IndexLast && + "unexpected CoroSubFnInst index argument"); + return static_cast<ResumeKind>(Index); + } + + ConstantInt *getRawIndex() const { + return cast<ConstantInt>(getArgOperand(IndexArg)); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_subfn_addr; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.alloc instruction. +class LLVM_LIBRARY_VISIBILITY CoroAllocInst : public IntrinsicInst { +public: + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_alloc; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.alloc instruction. +class LLVM_LIBRARY_VISIBILITY CoroIdInst : public IntrinsicInst { + enum { AlignArg, PromiseArg, CoroutineArg, InfoArg }; + +public: + CoroAllocInst *getCoroAlloc() { + for (User *U : users()) + if (auto *CA = dyn_cast<CoroAllocInst>(U)) + return CA; + return nullptr; + } + + IntrinsicInst *getCoroBegin() { + for (User *U : users()) + if (auto *II = dyn_cast<IntrinsicInst>(U)) + if (II->getIntrinsicID() == Intrinsic::coro_begin) + return II; + llvm_unreachable("no coro.begin associated with coro.id"); + } + + AllocaInst *getPromise() const { + Value *Arg = getArgOperand(PromiseArg); + return isa<ConstantPointerNull>(Arg) + ? nullptr + : cast<AllocaInst>(Arg->stripPointerCasts()); + } + + void clearPromise() { + Value *Arg = getArgOperand(PromiseArg); + setArgOperand(PromiseArg, + ConstantPointerNull::get(Type::getInt8PtrTy(getContext()))); + if (isa<AllocaInst>(Arg)) + return; + assert((isa<BitCastInst>(Arg) || isa<GetElementPtrInst>(Arg)) && + "unexpected instruction designating the promise"); + // TODO: Add a check that any remaining users of Inst are after coro.begin + // or add code to move the users after coro.begin. + auto *Inst = cast<Instruction>(Arg); + if (Inst->use_empty()) { + Inst->eraseFromParent(); + return; + } + Inst->moveBefore(getCoroBegin()->getNextNode()); + } + + // Info argument of coro.id is + // fresh out of the frontend: null ; + // outlined : {Init, Return, Susp1, Susp2, ...} ; + // postsplit : [resume, destroy, cleanup] ; + // + // If parts of the coroutine were outlined to protect against undesirable + // code motion, these functions will be stored in a struct literal referred to + // by the Info parameter. Note: this is only needed before coroutine is split. + // + // After coroutine is split, resume functions are stored in an array + // referred to by this parameter. 
+ + struct Info { + ConstantStruct *OutlinedParts = nullptr; + ConstantArray *Resumers = nullptr; + + bool hasOutlinedParts() const { return OutlinedParts != nullptr; } + bool isPostSplit() const { return Resumers != nullptr; } + bool isPreSplit() const { return !isPostSplit(); } + }; + Info getInfo() const { + Info Result; + auto *GV = dyn_cast<GlobalVariable>(getRawInfo()); + if (!GV) + return Result; + + assert(GV->isConstant() && GV->hasDefinitiveInitializer()); + Constant *Initializer = GV->getInitializer(); + if ((Result.OutlinedParts = dyn_cast<ConstantStruct>(Initializer))) + return Result; + + Result.Resumers = cast<ConstantArray>(Initializer); + return Result; + } + Constant *getRawInfo() const { + return cast<Constant>(getArgOperand(InfoArg)->stripPointerCasts()); + } + + void setInfo(Constant *C) { setArgOperand(InfoArg, C); } + + Function *getCoroutine() const { + return cast<Function>(getArgOperand(CoroutineArg)->stripPointerCasts()); + } + void setCoroutineSelf() { + assert(isa<ConstantPointerNull>(getArgOperand(CoroutineArg)) && + "Coroutine argument is already assigned"); + auto *const Int8PtrTy = Type::getInt8PtrTy(getContext()); + setArgOperand(CoroutineArg, + ConstantExpr::getBitCast(getFunction(), Int8PtrTy)); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_id; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.frame instruction. +class LLVM_LIBRARY_VISIBILITY CoroFrameInst : public IntrinsicInst { +public: + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_frame; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.free instruction. +class LLVM_LIBRARY_VISIBILITY CoroFreeInst : public IntrinsicInst { + enum { IdArg, FrameArg }; + +public: + Value *getFrame() const { return getArgOperand(FrameArg); } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_free; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This class represents the llvm.coro.begin instruction. +class LLVM_LIBRARY_VISIBILITY CoroBeginInst : public IntrinsicInst { + enum { IdArg, MemArg }; + +public: + CoroIdInst *getId() const { return cast<CoroIdInst>(getArgOperand(IdArg)); } + + Value *getMem() const { return getArgOperand(MemArg); } + + // Methods for support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_begin; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.save instruction. 
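+/// It marks the point at which the coroutine state is considered saved; any
+/// code between coro.save and the matching coro.suspend may already resume
+/// the coroutine (see the spill logic in CoroFrame.cpp).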
+class LLVM_LIBRARY_VISIBILITY CoroSaveInst : public IntrinsicInst { +public: + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_save; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.promise instruction. +class LLVM_LIBRARY_VISIBILITY CoroPromiseInst : public IntrinsicInst { + enum { FrameArg, AlignArg, FromArg }; + +public: + bool isFromPromise() const { + return cast<Constant>(getArgOperand(FromArg))->isOneValue(); + } + unsigned getAlignment() const { + return cast<ConstantInt>(getArgOperand(AlignArg))->getZExtValue(); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_promise; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.suspend instruction. +class LLVM_LIBRARY_VISIBILITY CoroSuspendInst : public IntrinsicInst { + enum { SaveArg, FinalArg }; + +public: + CoroSaveInst *getCoroSave() const { + Value *Arg = getArgOperand(SaveArg); + if (auto *SI = dyn_cast<CoroSaveInst>(Arg)) + return SI; + assert(isa<ConstantTokenNone>(Arg)); + return nullptr; + } + bool isFinal() const { + return cast<Constant>(getArgOperand(FinalArg))->isOneValue(); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_suspend; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.size instruction. +class LLVM_LIBRARY_VISIBILITY CoroSizeInst : public IntrinsicInst { +public: + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_size; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +/// This represents the llvm.coro.end instruction. +class LLVM_LIBRARY_VISIBILITY CoroEndInst : public IntrinsicInst { + enum { FrameArg, UnwindArg }; + +public: + bool isFallthrough() const { return !isUnwind(); } + bool isUnwind() const { + return cast<Constant>(getArgOperand(UnwindArg))->isOneValue(); + } + + // Methods to support type inquiry through isa, cast, and dyn_cast: + static inline bool classof(const IntrinsicInst *I) { + return I->getIntrinsicID() == Intrinsic::coro_end; + } + static inline bool classof(const Value *V) { + return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V)); + } +}; + +} // End namespace llvm. diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroInternal.h b/contrib/llvm/lib/Transforms/Coroutines/CoroInternal.h new file mode 100644 index 0000000..1eac88d --- /dev/null +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -0,0 +1,107 @@ +//===- CoroInternal.h - Internal Coroutine interfaces ---------*- C++ -*---===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// Common definitions/declarations used internally by coroutine lowering passes. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H +#define LLVM_LIB_TRANSFORMS_COROUTINES_COROINTERNAL_H + +#include "CoroInstr.h" +#include "llvm/Transforms/Coroutines.h" + +namespace llvm { + +class CallGraph; +class CallGraphSCC; +class PassRegistry; + +void initializeCoroEarlyPass(PassRegistry &); +void initializeCoroSplitPass(PassRegistry &); +void initializeCoroElidePass(PassRegistry &); +void initializeCoroCleanupPass(PassRegistry &); + +// CoroEarly pass marks every function that has coro.begin with a string +// attribute "coroutine.presplit"="0". CoroSplit pass processes the coroutine +// twice. First, it lets it go through complete IPO optimization pipeline as a +// single function. It forces restart of the pipeline by inserting an indirect +// call to an empty function "coro.devirt.trigger" which is devirtualized by +// CoroElide pass that triggers a restart of the pipeline by CGPassManager. +// When CoroSplit pass sees the same coroutine the second time, it splits it up, +// adds coroutine subfunctions to the SCC to be processed by IPO pipeline. + +#define CORO_PRESPLIT_ATTR "coroutine.presplit" +#define UNPREPARED_FOR_SPLIT "0" +#define PREPARED_FOR_SPLIT "1" + +#define CORO_DEVIRT_TRIGGER_FN "coro.devirt.trigger" + +namespace coro { + +bool declaresIntrinsics(Module &M, std::initializer_list<StringRef>); +void replaceAllCoroAllocs(CoroBeginInst *CB, bool Replacement); +void replaceAllCoroFrees(CoroBeginInst *CB, Value *Replacement); +void replaceCoroFree(CoroIdInst *CoroId, bool Elide); +void updateCallGraph(Function &Caller, ArrayRef<Function *> Funcs, + CallGraph &CG, CallGraphSCC &SCC); + +// Keeps data and helper functions for lowering coroutine intrinsics. +struct LowererBase { + Module &TheModule; + LLVMContext &Context; + PointerType *const Int8Ptr; + FunctionType *const ResumeFnType; + ConstantPointerNull *const NullPtr; + + LowererBase(Module &M); + Value *makeSubFnCall(Value *Arg, int Index, Instruction *InsertPt); +}; + +// Holds structural Coroutine Intrinsics for a particular function and other +// values used during CoroSplit pass. +struct LLVM_LIBRARY_VISIBILITY Shape { + CoroBeginInst *CoroBegin; + SmallVector<CoroEndInst *, 4> CoroEnds; + SmallVector<CoroSizeInst *, 2> CoroSizes; + SmallVector<CoroSuspendInst *, 4> CoroSuspends; + + // Field Indexes for known coroutine frame fields. + enum { + ResumeField, + DestroyField, + PromiseField, + IndexField, + LastKnownField = IndexField + }; + + StructType *FrameTy; + Instruction *FramePtr; + BasicBlock *AllocaSpillBlock; + SwitchInst *ResumeSwitch; + AllocaInst *PromiseAlloca; + bool HasFinalSuspend; + + IntegerType *getIndexType() const { + assert(FrameTy && "frame type not assigned"); + return cast<IntegerType>(FrameTy->getElementType(IndexField)); + } + ConstantInt *getIndex(uint64_t Value) const { + return ConstantInt::get(getIndexType(), Value); + } + + Shape() = default; + explicit Shape(Function &F) { buildFrom(F); } + void buildFrom(Function &F); +}; + +void buildCoroutineFrame(Function &F, Shape &Shape); + +} // End namespace coro. 
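
The CORO_PRESPLIT_ATTR handshake described in the comment above comes down to a string attribute that CoroEarly sets to "0", CoroSplit flips to "1", and the second CoroSplit visit finally removes. A minimal sketch, assuming only the macros defined in this header, of how a pass could query that state (the helper itself is illustrative, not part of the patch):

    static bool isUnsplitCoroutine(const Function &F) {
      if (!F.hasFnAttribute(CORO_PRESPLIT_ATTR))
        return false; // either not a coroutine, or already split
      return F.getFnAttribute(CORO_PRESPLIT_ATTR).getValueAsString() ==
             UNPREPARED_FOR_SPLIT;
    }
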
+} // End namespace llvm + +#endif diff --git a/contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp new file mode 100644 index 0000000..7a3f4f6 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -0,0 +1,640 @@ +//===- CoroSplit.cpp - Converts a coroutine into a state machine ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// This pass builds the coroutine frame and outlines resume and destroy parts +// of the coroutine into separate functions. +// +// We present a coroutine to an LLVM as an ordinary function with suspension +// points marked up with intrinsics. We let the optimizer party on the coroutine +// as a single function for as long as possible. Shortly before the coroutine is +// eligible to be inlined into its callers, we split up the coroutine into parts +// corresponding to an initial, resume and destroy invocations of the coroutine, +// add them to the current SCC and restart the IPO pipeline to optimize the +// coroutine subfunctions we extracted before proceeding to the caller of the +// coroutine. +//===----------------------------------------------------------------------===// + +#include "CoroInternal.h" +#include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +using namespace llvm; + +#define DEBUG_TYPE "coro-split" + +// Create an entry block for a resume function with a switch that will jump to +// suspend points. +static BasicBlock *createResumeEntryBlock(Function &F, coro::Shape &Shape) { + LLVMContext &C = F.getContext(); + + // resume.entry: + // %index.addr = getelementptr inbounds %f.Frame, %f.Frame* %FramePtr, i32 0, + // i32 2 + // % index = load i32, i32* %index.addr + // switch i32 %index, label %unreachable [ + // i32 0, label %resume.0 + // i32 1, label %resume.1 + // ... + // ] + + auto *NewEntry = BasicBlock::Create(C, "resume.entry", &F); + auto *UnreachBB = BasicBlock::Create(C, "unreachable", &F); + + IRBuilder<> Builder(NewEntry); + auto *FramePtr = Shape.FramePtr; + auto *FrameTy = Shape.FrameTy; + auto *GepIndex = Builder.CreateConstInBoundsGEP2_32( + FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr"); + auto *Index = Builder.CreateLoad(GepIndex, "index"); + auto *Switch = + Builder.CreateSwitch(Index, UnreachBB, Shape.CoroSuspends.size()); + Shape.ResumeSwitch = Switch; + + size_t SuspendIndex = 0; + for (CoroSuspendInst *S : Shape.CoroSuspends) { + ConstantInt *IndexVal = Shape.getIndex(SuspendIndex); + + // Replace CoroSave with a store to Index: + // %index.addr = getelementptr %f.frame... (index field number) + // store i32 0, i32* %index.addr1 + auto *Save = S->getCoroSave(); + Builder.SetInsertPoint(Save); + if (S->isFinal()) { + // Final suspend point is represented by storing zero in ResumeFnAddr. 
+ auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(FrameTy, FramePtr, 0, + 0, "ResumeFn.addr"); + auto *NullPtr = ConstantPointerNull::get(cast<PointerType>( + cast<PointerType>(GepIndex->getType())->getElementType())); + Builder.CreateStore(NullPtr, GepIndex); + } else { + auto *GepIndex = Builder.CreateConstInBoundsGEP2_32( + FrameTy, FramePtr, 0, coro::Shape::IndexField, "index.addr"); + Builder.CreateStore(IndexVal, GepIndex); + } + Save->replaceAllUsesWith(ConstantTokenNone::get(C)); + Save->eraseFromParent(); + + // Split block before and after coro.suspend and add a jump from an entry + // switch: + // + // whateverBB: + // whatever + // %0 = call i8 @llvm.coro.suspend(token none, i1 false) + // switch i8 %0, label %suspend[i8 0, label %resume + // i8 1, label %cleanup] + // becomes: + // + // whateverBB: + // whatever + // br label %resume.0.landing + // + // resume.0: ; <--- jump from the switch in the resume.entry + // %0 = tail call i8 @llvm.coro.suspend(token none, i1 false) + // br label %resume.0.landing + // + // resume.0.landing: + // %1 = phi i8[-1, %whateverBB], [%0, %resume.0] + // switch i8 % 1, label %suspend [i8 0, label %resume + // i8 1, label %cleanup] + + auto *SuspendBB = S->getParent(); + auto *ResumeBB = + SuspendBB->splitBasicBlock(S, "resume." + Twine(SuspendIndex)); + auto *LandingBB = ResumeBB->splitBasicBlock( + S->getNextNode(), ResumeBB->getName() + Twine(".landing")); + Switch->addCase(IndexVal, ResumeBB); + + cast<BranchInst>(SuspendBB->getTerminator())->setSuccessor(0, LandingBB); + auto *PN = PHINode::Create(Builder.getInt8Ty(), 2, "", &LandingBB->front()); + S->replaceAllUsesWith(PN); + PN->addIncoming(Builder.getInt8(-1), SuspendBB); + PN->addIncoming(S, ResumeBB); + + ++SuspendIndex; + } + + Builder.SetInsertPoint(UnreachBB); + Builder.CreateUnreachable(); + + return NewEntry; +} + +// In Resumers, we replace fallthrough coro.end with ret void and delete the +// rest of the block. +static void replaceFallthroughCoroEnd(IntrinsicInst *End, + ValueToValueMapTy &VMap) { + auto *NewE = cast<IntrinsicInst>(VMap[End]); + ReturnInst::Create(NewE->getContext(), nullptr, NewE); + + // Remove the rest of the block, by splitting it into an unreachable block. + auto *BB = NewE->getParent(); + BB->splitBasicBlock(NewE); + BB->getTerminator()->eraseFromParent(); +} + +// Rewrite final suspend point handling. We do not use suspend index to +// represent the final suspend point. Instead we zero-out ResumeFnAddr in the +// coroutine frame, since it is undefined behavior to resume a coroutine +// suspended at the final suspend point. Thus, in the resume function, we can +// simply remove the last case (when coro::Shape is built, the final suspend +// point (if present) is always the last element of CoroSuspends array). +// In the destroy function, we add a code sequence to check if ResumeFnAddress +// is Null, and if so, jump to the appropriate label to handle cleanup from the +// final suspend point. 
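
Concretely, the destroy clone produced later keys its final-suspend cleanup off a null resume slot rather than off the stored index. A self-contained sketch of that check, using made-up types purely for illustration:

    struct FrameSketch {                  // stand-in for the real coroutine frame
      void (*ResumeFn)(FrameSketch *);    // zeroed at the final suspend point
      void (*DestroyFn)(FrameSketch *);
      int Index;                          // only meaningful for non-final suspends
    };

    static void destroySketch(FrameSketch *F,
                              void (*FinalCleanup)(FrameSketch *),
                              void (*IndexedCleanup)(FrameSketch *, int)) {
      if (F->ResumeFn == nullptr)
        FinalCleanup(F);                  // parked at the final suspend point
      else
        IndexedCleanup(F, F->Index);      // normal path: dispatch on the index
    }
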
+static void handleFinalSuspend(IRBuilder<> &Builder, Value *FramePtr, + coro::Shape &Shape, SwitchInst *Switch, + bool IsDestroy) { + assert(Shape.HasFinalSuspend); + auto FinalCase = --Switch->case_end(); + BasicBlock *ResumeBB = FinalCase.getCaseSuccessor(); + Switch->removeCase(FinalCase); + if (IsDestroy) { + BasicBlock *OldSwitchBB = Switch->getParent(); + auto *NewSwitchBB = OldSwitchBB->splitBasicBlock(Switch, "Switch"); + Builder.SetInsertPoint(OldSwitchBB->getTerminator()); + auto *GepIndex = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, FramePtr, + 0, 0, "ResumeFn.addr"); + auto *Load = Builder.CreateLoad(GepIndex); + auto *NullPtr = + ConstantPointerNull::get(cast<PointerType>(Load->getType())); + auto *Cond = Builder.CreateICmpEQ(Load, NullPtr); + Builder.CreateCondBr(Cond, ResumeBB, NewSwitchBB); + OldSwitchBB->getTerminator()->eraseFromParent(); + } +} + +// Create a resume clone by cloning the body of the original function, setting +// new entry block and replacing coro.suspend an appropriate value to force +// resume or cleanup pass for every suspend point. +static Function *createClone(Function &F, Twine Suffix, coro::Shape &Shape, + BasicBlock *ResumeEntry, int8_t FnIndex) { + Module *M = F.getParent(); + auto *FrameTy = Shape.FrameTy; + auto *FnPtrTy = cast<PointerType>(FrameTy->getElementType(0)); + auto *FnTy = cast<FunctionType>(FnPtrTy->getElementType()); + + Function *NewF = + Function::Create(FnTy, GlobalValue::LinkageTypes::InternalLinkage, + F.getName() + Suffix, M); + NewF->addAttribute(1, Attribute::NonNull); + NewF->addAttribute(1, Attribute::NoAlias); + + ValueToValueMapTy VMap; + // Replace all args with undefs. The buildCoroutineFrame algorithm already + // rewritten access to the args that occurs after suspend points with loads + // and stores to/from the coroutine frame. + for (Argument &A : F.getArgumentList()) + VMap[&A] = UndefValue::get(A.getType()); + + SmallVector<ReturnInst *, 4> Returns; + + if (DISubprogram *SP = F.getSubprogram()) { + // If we have debug info, add mapping for the metadata nodes that should not + // be cloned by CloneFunctionInfo. + auto &MD = VMap.MD(); + MD[SP->getUnit()].reset(SP->getUnit()); + MD[SP->getType()].reset(SP->getType()); + MD[SP->getFile()].reset(SP->getFile()); + } + CloneFunctionInto(NewF, &F, VMap, /*ModuleLevelChanges=*/true, Returns); + + // Remove old returns. + for (ReturnInst *Return : Returns) + changeToUnreachable(Return, /*UseLLVMTrap=*/false); + + // Remove old return attributes. + NewF->removeAttributes( + AttributeSet::ReturnIndex, + AttributeSet::get( + NewF->getContext(), AttributeSet::ReturnIndex, + AttributeFuncs::typeIncompatible(NewF->getReturnType()))); + + // Make AllocaSpillBlock the new entry block. + auto *SwitchBB = cast<BasicBlock>(VMap[ResumeEntry]); + auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]); + Entry->moveBefore(&NewF->getEntryBlock()); + Entry->getTerminator()->eraseFromParent(); + BranchInst::Create(SwitchBB, Entry); + Entry->setName("entry" + Suffix); + + // Clear all predecessors of the new entry block. + auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]); + Entry->replaceAllUsesWith(Switch->getDefaultDest()); + + IRBuilder<> Builder(&NewF->getEntryBlock().front()); + + // Remap frame pointer. + Argument *NewFramePtr = &NewF->getArgumentList().front(); + Value *OldFramePtr = cast<Value>(VMap[Shape.FramePtr]); + NewFramePtr->takeName(OldFramePtr); + OldFramePtr->replaceAllUsesWith(NewFramePtr); + + // Remap vFrame pointer. 
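
For orientation, the net effect of the remapping in this function: each clone receives the frame pointer as its only argument, and both the typed frame pointer and the raw i8* that coro.begin used to produce are rebuilt from that argument. Roughly, in C++ terms with illustrative names only:

    struct f_Frame;                        // opaque stand-in for %f.Frame
    void f_resume(f_Frame *FramePtr) {     // shape of f.resume / f.destroy / f.cleanup
      void *vFrame = FramePtr;             // replaces the old coro.begin result
      (void)vFrame;
      // ...cloned body, entered through the resume.entry switch...
    }
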
+ auto *NewVFrame = Builder.CreateBitCast( + NewFramePtr, Type::getInt8PtrTy(Builder.getContext()), "vFrame"); + Value *OldVFrame = cast<Value>(VMap[Shape.CoroBegin]); + OldVFrame->replaceAllUsesWith(NewVFrame); + + // Rewrite final suspend handling as it is not done via switch (allows to + // remove final case from the switch, since it is undefined behavior to resume + // the coroutine suspended at the final suspend point. + if (Shape.HasFinalSuspend) { + auto *Switch = cast<SwitchInst>(VMap[Shape.ResumeSwitch]); + bool IsDestroy = FnIndex != 0; + handleFinalSuspend(Builder, NewFramePtr, Shape, Switch, IsDestroy); + } + + // Replace coro suspend with the appropriate resume index. + // Replacing coro.suspend with (0) will result in control flow proceeding to + // a resume label associated with a suspend point, replacing it with (1) will + // result in control flow proceeding to a cleanup label associated with this + // suspend point. + auto *NewValue = Builder.getInt8(FnIndex ? 1 : 0); + for (CoroSuspendInst *CS : Shape.CoroSuspends) { + auto *MappedCS = cast<CoroSuspendInst>(VMap[CS]); + MappedCS->replaceAllUsesWith(NewValue); + MappedCS->eraseFromParent(); + } + + // Remove coro.end intrinsics. + replaceFallthroughCoroEnd(Shape.CoroEnds.front(), VMap); + // FIXME: coming in upcoming patches: + // replaceUnwindCoroEnds(Shape.CoroEnds, VMap); + + // Eliminate coro.free from the clones, replacing it with 'null' in cleanup, + // to suppress deallocation code. + coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]), + /*Elide=*/FnIndex == 2); + + NewF->setCallingConv(CallingConv::Fast); + + return NewF; +} + +static void removeCoroEnds(coro::Shape &Shape) { + for (CoroEndInst *CE : Shape.CoroEnds) + CE->eraseFromParent(); +} + +static void replaceFrameSize(coro::Shape &Shape) { + if (Shape.CoroSizes.empty()) + return; + + // In the same function all coro.sizes should have the same result type. + auto *SizeIntrin = Shape.CoroSizes.back(); + Module *M = SizeIntrin->getModule(); + const DataLayout &DL = M->getDataLayout(); + auto Size = DL.getTypeAllocSize(Shape.FrameTy); + auto *SizeConstant = ConstantInt::get(SizeIntrin->getType(), Size); + + for (CoroSizeInst *CS : Shape.CoroSizes) { + CS->replaceAllUsesWith(SizeConstant); + CS->eraseFromParent(); + } +} + +// Create a global constant array containing pointers to functions provided and +// set Info parameter of CoroBegin to point at this constant. Example: +// +// @f.resumers = internal constant [2 x void(%f.frame*)*] +// [void(%f.frame*)* @f.resume, void(%f.frame*)* @f.destroy] +// define void @f() { +// ... +// call i8* @llvm.coro.begin(i8* null, i32 0, i8* null, +// i8* bitcast([2 x void(%f.frame*)*] * @f.resumers to i8*)) +// +// Assumes that all the functions have the same signature. +static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin, + std::initializer_list<Function *> Fns) { + + SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end()); + assert(!Args.empty()); + Function *Part = *Fns.begin(); + Module *M = Part->getParent(); + auto *ArrTy = ArrayType::get(Part->getType(), Args.size()); + + auto *ConstVal = ConstantArray::get(ArrTy, Args); + auto *GV = new GlobalVariable(*M, ConstVal->getType(), /*isConstant=*/true, + GlobalVariable::PrivateLinkage, ConstVal, + F.getName() + Twine(".resumers")); + + // Update coro.begin instruction to refer to this constant. 
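
The @f.resumers constant wired in just below is what makes the coro.id info operand round-trip: CoroIdInst::getInfo() (CoroInstr.h above) will now see a ConstantArray and report the coroutine as post-split, and CoroElide relies on it to determine which function to call. A sketch of the consumer side, with an illustrative helper name:

    static Function *getResumer(CoroIdInst *CoroId, unsigned Index) {
      auto Info = CoroId->getInfo();
      if (!Info.isPostSplit())
        return nullptr;                             // nothing recorded yet
      Constant *C = Info.Resumers->getAggregateElement(Index);
      return cast<Function>(C->stripPointerCasts());
    }
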
+ LLVMContext &C = F.getContext(); + auto *BC = ConstantExpr::getPointerCast(GV, Type::getInt8PtrTy(C)); + CoroBegin->getId()->setInfo(BC); +} + +// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame. +static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, + Function *DestroyFn, Function *CleanupFn) { + + IRBuilder<> Builder(Shape.FramePtr->getNextNode()); + auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32( + Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField, + "resume.addr"); + Builder.CreateStore(ResumeFn, ResumeAddr); + + Value *DestroyOrCleanupFn = DestroyFn; + + CoroIdInst *CoroId = Shape.CoroBegin->getId(); + if (CoroAllocInst *CA = CoroId->getCoroAlloc()) { + // If there is a CoroAlloc and it returns false (meaning we elide the + // allocation, use CleanupFn instead of DestroyFn). + DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn); + } + + auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32( + Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField, + "destroy.addr"); + Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr); +} + +static void postSplitCleanup(Function &F) { + removeUnreachableBlocks(F); + llvm::legacy::FunctionPassManager FPM(F.getParent()); + + FPM.add(createVerifierPass()); + FPM.add(createSCCPPass()); + FPM.add(createCFGSimplificationPass()); + FPM.add(createEarlyCSEPass()); + FPM.add(createCFGSimplificationPass()); + + FPM.doInitialization(); + FPM.run(F); + FPM.doFinalization(); +} + +// Coroutine has no suspend points. Remove heap allocation for the coroutine +// frame if possible. +static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) { + auto *CoroId = CoroBegin->getId(); + auto *AllocInst = CoroId->getCoroAlloc(); + coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr); + if (AllocInst) { + IRBuilder<> Builder(AllocInst); + // FIXME: Need to handle overaligned members. + auto *Frame = Builder.CreateAlloca(FrameTy); + auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy()); + AllocInst->replaceAllUsesWith(Builder.getFalse()); + AllocInst->eraseFromParent(); + CoroBegin->replaceAllUsesWith(VFrame); + } else { + CoroBegin->replaceAllUsesWith(CoroBegin->getMem()); + } + CoroBegin->eraseFromParent(); +} + +// look for a very simple pattern +// coro.save +// no other calls +// resume or destroy call +// coro.suspend +// +// If there are other calls between coro.save and coro.suspend, they can +// potentially resume or destroy the coroutine, so it is unsafe to eliminate a +// suspend point. +static bool simplifySuspendPoint(CoroSuspendInst *Suspend, + CoroBeginInst *CoroBegin) { + auto *Save = Suspend->getCoroSave(); + auto *BB = Suspend->getParent(); + if (BB != Save->getParent()) + return false; + + CallSite SingleCallSite; + + // Check that we have only one CallSite. + for (Instruction *I = Save->getNextNode(); I != Suspend; + I = I->getNextNode()) { + if (isa<CoroFrameInst>(I)) + continue; + if (isa<CoroSubFnInst>(I)) + continue; + if (CallSite CS = CallSite(I)) { + if (SingleCallSite) + return false; + else + SingleCallSite = CS; + } + } + auto *CallInstr = SingleCallSite.getInstruction(); + if (!CallInstr) + return false; + + auto *Callee = SingleCallSite.getCalledValue()->stripPointerCasts(); + + // See if the callsite is for resumption or destruction of the coroutine. + auto *SubFn = dyn_cast<CoroSubFnInst>(Callee); + if (!SubFn) + return false; + + // Does not refer to the current coroutine, we cannot do anything with it. 
+ if (SubFn->getFrame() != CoroBegin) + return false; + + // Replace llvm.coro.suspend with the value that results in resumption over + // the resume or cleanup path. + Suspend->replaceAllUsesWith(SubFn->getRawIndex()); + Suspend->eraseFromParent(); + Save->eraseFromParent(); + + // No longer need a call to coro.resume or coro.destroy. + CallInstr->eraseFromParent(); + + if (SubFn->user_empty()) + SubFn->eraseFromParent(); + + return true; +} + +// Remove suspend points that are simplified. +static void simplifySuspendPoints(coro::Shape &Shape) { + auto &S = Shape.CoroSuspends; + size_t I = 0, N = S.size(); + if (N == 0) + return; + for (;;) { + if (simplifySuspendPoint(S[I], Shape.CoroBegin)) { + if (--N == I) + break; + std::swap(S[I], S[N]); + continue; + } + if (++I == N) + break; + } + S.resize(N); +} + +static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) { + coro::Shape Shape(F); + if (!Shape.CoroBegin) + return; + + simplifySuspendPoints(Shape); + buildCoroutineFrame(F, Shape); + replaceFrameSize(Shape); + + // If there are no suspend points, no split required, just remove + // the allocation and deallocation blocks, they are not needed. + if (Shape.CoroSuspends.empty()) { + handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy); + removeCoroEnds(Shape); + postSplitCleanup(F); + coro::updateCallGraph(F, {}, CG, SCC); + return; + } + + auto *ResumeEntry = createResumeEntryBlock(F, Shape); + auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0); + auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1); + auto CleanupClone = createClone(F, ".cleanup", Shape, ResumeEntry, 2); + + // We no longer need coro.end in F. + removeCoroEnds(Shape); + + postSplitCleanup(F); + postSplitCleanup(*ResumeClone); + postSplitCleanup(*DestroyClone); + postSplitCleanup(*CleanupClone); + + // Store addresses resume/destroy/cleanup functions in the coroutine frame. + updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone); + + // Create a constant array referring to resume/destroy/clone functions pointed + // by the last argument of @llvm.coro.info, so that CoroElide pass can + // determined correct function to call. + setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone}); + + // Update call graph and add the functions we created to the SCC. + coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC); +} + +// When we see the coroutine the first time, we insert an indirect call to a +// devirt trigger function and mark the coroutine that it is now ready for +// split. +static void prepareForSplit(Function &F, CallGraph &CG) { + Module &M = *F.getParent(); +#ifndef NDEBUG + Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN); + assert(DevirtFn && "coro.devirt.trigger function not found"); +#endif + + F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT); + + // Insert an indirect call sequence that will be devirtualized by CoroElide + // pass: + // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1) + // %1 = bitcast i8* %0 to void(i8*)* + // call void %1(i8* null) + coro::LowererBase Lowerer(M); + Instruction *InsertPt = F.getEntryBlock().getTerminator(); + auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(F.getContext())); + auto *DevirtFnAddr = + Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt); + auto *IndirectCall = CallInst::Create(DevirtFnAddr, Null, "", InsertPt); + + // Update CG graph with an indirect call we just added. 
+ CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode()); +} + +// Make sure that there is a devirtualization trigger function that CoroSplit +// pass uses the force restart CGSCC pipeline. If devirt trigger function is not +// found, we will create one and add it to the current SCC. +static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) { + Module &M = CG.getModule(); + if (M.getFunction(CORO_DEVIRT_TRIGGER_FN)) + return; + + LLVMContext &C = M.getContext(); + auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C), + /*IsVarArgs=*/false); + Function *DevirtFn = + Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage, + CORO_DEVIRT_TRIGGER_FN, &M); + DevirtFn->addFnAttr(Attribute::AlwaysInline); + auto *Entry = BasicBlock::Create(C, "entry", DevirtFn); + ReturnInst::Create(C, Entry); + + auto *Node = CG.getOrInsertFunction(DevirtFn); + + SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end()); + Nodes.push_back(Node); + SCC.initialize(Nodes); +} + +//===----------------------------------------------------------------------===// +// Top Level Driver +//===----------------------------------------------------------------------===// + +namespace { + +struct CoroSplit : public CallGraphSCCPass { + static char ID; // Pass identification, replacement for typeid + CoroSplit() : CallGraphSCCPass(ID) {} + + bool Run = false; + + // A coroutine is identified by the presence of coro.begin intrinsic, if + // we don't have any, this pass has nothing to do. + bool doInitialization(CallGraph &CG) override { + Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"}); + return CallGraphSCCPass::doInitialization(CG); + } + + bool runOnSCC(CallGraphSCC &SCC) override { + if (!Run) + return false; + + // Find coroutines for processing. + SmallVector<Function *, 4> Coroutines; + for (CallGraphNode *CGN : SCC) + if (auto *F = CGN->getFunction()) + if (F->hasFnAttribute(CORO_PRESPLIT_ATTR)) + Coroutines.push_back(F); + + if (Coroutines.empty()) + return false; + + CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + createDevirtTriggerFunc(CG, SCC); + + for (Function *F : Coroutines) { + Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR); + StringRef Value = Attr.getValueAsString(); + DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName() + << "' state: " << Value << "\n"); + if (Value == UNPREPARED_FOR_SPLIT) { + prepareForSplit(*F, CG); + continue; + } + F->removeFnAttr(CORO_PRESPLIT_ATTR); + splitCoroutine(*F, CG, SCC); + } + return true; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + CallGraphSCCPass::getAnalysisUsage(AU); + } +}; +} + +char CoroSplit::ID = 0; +INITIALIZE_PASS( + CoroSplit, "coro-split", + "Split coroutine into a set of functions driving its state machine", false, + false) + +Pass *llvm::createCoroSplitPass() { return new CoroSplit(); } diff --git a/contrib/llvm/lib/Transforms/Coroutines/Coroutines.cpp b/contrib/llvm/lib/Transforms/Coroutines/Coroutines.cpp new file mode 100644 index 0000000..877ec34 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Coroutines/Coroutines.cpp @@ -0,0 +1,314 @@ +//===-- Coroutines.cpp ----------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// This file implements the common infrastructure for Coroutine Passes. +//===----------------------------------------------------------------------===// + +#include "CoroInternal.h" +#include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/InitializePasses.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/PassManagerBuilder.h" +#include "llvm/Transforms/Utils/Local.h" + +using namespace llvm; + +void llvm::initializeCoroutines(PassRegistry &Registry) { + initializeCoroEarlyPass(Registry); + initializeCoroSplitPass(Registry); + initializeCoroElidePass(Registry); + initializeCoroCleanupPass(Registry); +} + +static void addCoroutineOpt0Passes(const PassManagerBuilder &Builder, + legacy::PassManagerBase &PM) { + PM.add(createCoroSplitPass()); + PM.add(createCoroElidePass()); + + PM.add(createBarrierNoopPass()); + PM.add(createCoroCleanupPass()); +} + +static void addCoroutineEarlyPasses(const PassManagerBuilder &Builder, + legacy::PassManagerBase &PM) { + PM.add(createCoroEarlyPass()); +} + +static void addCoroutineScalarOptimizerPasses(const PassManagerBuilder &Builder, + legacy::PassManagerBase &PM) { + PM.add(createCoroElidePass()); +} + +static void addCoroutineSCCPasses(const PassManagerBuilder &Builder, + legacy::PassManagerBase &PM) { + PM.add(createCoroSplitPass()); +} + +static void addCoroutineOptimizerLastPasses(const PassManagerBuilder &Builder, + legacy::PassManagerBase &PM) { + PM.add(createCoroCleanupPass()); +} + +void llvm::addCoroutinePassesToExtensionPoints(PassManagerBuilder &Builder) { + Builder.addExtension(PassManagerBuilder::EP_EarlyAsPossible, + addCoroutineEarlyPasses); + Builder.addExtension(PassManagerBuilder::EP_EnabledOnOptLevel0, + addCoroutineOpt0Passes); + Builder.addExtension(PassManagerBuilder::EP_CGSCCOptimizerLate, + addCoroutineSCCPasses); + Builder.addExtension(PassManagerBuilder::EP_ScalarOptimizerLate, + addCoroutineScalarOptimizerPasses); + Builder.addExtension(PassManagerBuilder::EP_OptimizerLast, + addCoroutineOptimizerLastPasses); +} + +// Construct the lowerer base class and initialize its members. +coro::LowererBase::LowererBase(Module &M) + : TheModule(M), Context(M.getContext()), + Int8Ptr(Type::getInt8PtrTy(Context)), + ResumeFnType(FunctionType::get(Type::getVoidTy(Context), Int8Ptr, + /*isVarArg=*/false)), + NullPtr(ConstantPointerNull::get(Int8Ptr)) {} + +// Creates a sequence of instructions to obtain a resume function address using +// llvm.coro.subfn.addr. It generates the following sequence: +// +// call i8* @llvm.coro.subfn.addr(i8* %Arg, i8 %index) +// bitcast i8* %2 to void(i8*)* + +Value *coro::LowererBase::makeSubFnCall(Value *Arg, int Index, + Instruction *InsertPt) { + auto *IndexVal = ConstantInt::get(Type::getInt8Ty(Context), Index); + auto *Fn = Intrinsic::getDeclaration(&TheModule, Intrinsic::coro_subfn_addr); + + assert(Index >= CoroSubFnInst::IndexFirst && + Index < CoroSubFnInst::IndexLast && + "makeSubFnCall: Index value out of range"); + auto *Call = CallInst::Create(Fn, {Arg, IndexVal}, "", InsertPt); + + auto *Bitcast = + new BitCastInst(Call, ResumeFnType->getPointerTo(), "", InsertPt); + return Bitcast; +} + +#ifndef NDEBUG +static bool isCoroutineIntrinsicName(StringRef Name) { + // NOTE: Must be sorted! 
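
Stepping back to addCoroutinePassesToExtensionPoints above: a frontend that drives the legacy pipeline through PassManagerBuilder needs only that one extra call to get all four coroutine passes scheduled at the five extension points it registers. A hedged sketch of the wiring (standard boilerplate, not taken from this patch):

    #include "llvm/IR/LegacyPassManager.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Transforms/Coroutines.h"
    #include "llvm/Transforms/IPO/PassManagerBuilder.h"

    static void runWithCoroutines(llvm::Module &M) {
      llvm::PassManagerBuilder PMB;
      PMB.OptLevel = 2;
      llvm::addCoroutinePassesToExtensionPoints(PMB);
      llvm::legacy::PassManager MPM;
      PMB.populateModulePassManager(MPM);
      MPM.run(M); // CoroEarly/CoroSplit/CoroElide/CoroCleanup run at their hooks
    }
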
+ static const char *const CoroIntrinsics[] = { + "llvm.coro.alloc", "llvm.coro.begin", "llvm.coro.destroy", + "llvm.coro.done", "llvm.coro.end", "llvm.coro.frame", + "llvm.coro.free", "llvm.coro.id", "llvm.coro.param", + "llvm.coro.promise", "llvm.coro.resume", "llvm.coro.save", + "llvm.coro.size", "llvm.coro.subfn.addr", "llvm.coro.suspend", + }; + return Intrinsic::lookupLLVMIntrinsicByName(CoroIntrinsics, Name) != -1; +} +#endif + +// Verifies if a module has named values listed. Also, in debug mode verifies +// that names are intrinsic names. +bool coro::declaresIntrinsics(Module &M, + std::initializer_list<StringRef> List) { + + for (StringRef Name : List) { + assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic"); + if (M.getNamedValue(Name)) + return true; + } + + return false; +} + +// Replace all coro.frees associated with the provided CoroId either with 'null' +// if Elide is true and with its frame parameter otherwise. +void coro::replaceCoroFree(CoroIdInst *CoroId, bool Elide) { + SmallVector<CoroFreeInst *, 4> CoroFrees; + for (User *U : CoroId->users()) + if (auto CF = dyn_cast<CoroFreeInst>(U)) + CoroFrees.push_back(CF); + + if (CoroFrees.empty()) + return; + + Value *Replacement = + Elide ? ConstantPointerNull::get(Type::getInt8PtrTy(CoroId->getContext())) + : CoroFrees.front()->getFrame(); + + for (CoroFreeInst *CF : CoroFrees) { + CF->replaceAllUsesWith(Replacement); + CF->eraseFromParent(); + } +} + +// FIXME: This code is stolen from CallGraph::addToCallGraph(Function *F), which +// happens to be private. It is better for this functionality exposed by the +// CallGraph. +static void buildCGN(CallGraph &CG, CallGraphNode *Node) { + Function *F = Node->getFunction(); + + // Look for calls by this function. + for (Instruction &I : instructions(F)) + if (CallSite CS = CallSite(cast<Value>(&I))) { + const Function *Callee = CS.getCalledFunction(); + if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID())) + // Indirect calls of intrinsics are not allowed so no need to check. + // We can be more precise here by using TargetArg returned by + // Intrinsic::isLeaf. + Node->addCalledFunction(CS, CG.getCallsExternalNode()); + else if (!Callee->isIntrinsic()) + Node->addCalledFunction(CS, CG.getOrInsertFunction(Callee)); + } +} + +// Rebuild CGN after we extracted parts of the code from ParentFunc into +// NewFuncs. Builds CGNs for the NewFuncs and adds them to the current SCC. 
+void coro::updateCallGraph(Function &ParentFunc, ArrayRef<Function *> NewFuncs, + CallGraph &CG, CallGraphSCC &SCC) { + // Rebuild CGN from scratch for the ParentFunc + auto *ParentNode = CG[&ParentFunc]; + ParentNode->removeAllCalledFunctions(); + buildCGN(CG, ParentNode); + + SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end()); + + for (Function *F : NewFuncs) { + CallGraphNode *Callee = CG.getOrInsertFunction(F); + Nodes.push_back(Callee); + buildCGN(CG, Callee); + } + + SCC.initialize(Nodes); +} + +static void clear(coro::Shape &Shape) { + Shape.CoroBegin = nullptr; + Shape.CoroEnds.clear(); + Shape.CoroSizes.clear(); + Shape.CoroSuspends.clear(); + + Shape.FrameTy = nullptr; + Shape.FramePtr = nullptr; + Shape.AllocaSpillBlock = nullptr; + Shape.ResumeSwitch = nullptr; + Shape.PromiseAlloca = nullptr; + Shape.HasFinalSuspend = false; +} + +static CoroSaveInst *createCoroSave(CoroBeginInst *CoroBegin, + CoroSuspendInst *SuspendInst) { + Module *M = SuspendInst->getModule(); + auto *Fn = Intrinsic::getDeclaration(M, Intrinsic::coro_save); + auto *SaveInst = + cast<CoroSaveInst>(CallInst::Create(Fn, CoroBegin, "", SuspendInst)); + assert(!SuspendInst->getCoroSave()); + SuspendInst->setArgOperand(0, SaveInst); + return SaveInst; +} + +// Collect "interesting" coroutine intrinsics. +void coro::Shape::buildFrom(Function &F) { + size_t FinalSuspendIndex = 0; + clear(*this); + SmallVector<CoroFrameInst *, 8> CoroFrames; + for (Instruction &I : instructions(F)) { + if (auto II = dyn_cast<IntrinsicInst>(&I)) { + switch (II->getIntrinsicID()) { + default: + continue; + case Intrinsic::coro_size: + CoroSizes.push_back(cast<CoroSizeInst>(II)); + break; + case Intrinsic::coro_frame: + CoroFrames.push_back(cast<CoroFrameInst>(II)); + break; + case Intrinsic::coro_suspend: + CoroSuspends.push_back(cast<CoroSuspendInst>(II)); + if (CoroSuspends.back()->isFinal()) { + if (HasFinalSuspend) + report_fatal_error( + "Only one suspend point can be marked as final"); + HasFinalSuspend = true; + FinalSuspendIndex = CoroSuspends.size() - 1; + } + break; + case Intrinsic::coro_begin: { + auto CB = cast<CoroBeginInst>(II); + if (CB->getId()->getInfo().isPreSplit()) { + if (CoroBegin) + report_fatal_error( + "coroutine should have exactly one defining @llvm.coro.begin"); + CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); + CB->addAttribute(AttributeSet::ReturnIndex, Attribute::NoAlias); + CB->removeAttribute(AttributeSet::FunctionIndex, + Attribute::NoDuplicate); + CoroBegin = CB; + } + break; + } + case Intrinsic::coro_end: + CoroEnds.push_back(cast<CoroEndInst>(II)); + if (CoroEnds.back()->isFallthrough()) { + // Make sure that the fallthrough coro.end is the first element in the + // CoroEnds vector. + if (CoroEnds.size() > 1) { + if (CoroEnds.front()->isFallthrough()) + report_fatal_error( + "Only one coro.end can be marked as fallthrough"); + std::swap(CoroEnds.front(), CoroEnds.back()); + } + } + break; + } + } + } + + // If for some reason, we were not able to find coro.begin, bailout. + if (!CoroBegin) { + // Replace coro.frame which are supposed to be lowered to the result of + // coro.begin with undef. + auto *Undef = UndefValue::get(Type::getInt8PtrTy(F.getContext())); + for (CoroFrameInst *CF : CoroFrames) { + CF->replaceAllUsesWith(Undef); + CF->eraseFromParent(); + } + + // Replace all coro.suspend with undef and remove related coro.saves if + // present. 
+ for (CoroSuspendInst *CS : CoroSuspends) { + CS->replaceAllUsesWith(UndefValue::get(CS->getType())); + CS->eraseFromParent(); + if (auto *CoroSave = CS->getCoroSave()) + CoroSave->eraseFromParent(); + } + + // Replace all coro.ends with unreachable instruction. + for (CoroEndInst *CE : CoroEnds) + changeToUnreachable(CE, /*UseLLVMTrap=*/false); + + return; + } + + // The coro.free intrinsic is always lowered to the result of coro.begin. + for (CoroFrameInst *CF : CoroFrames) { + CF->replaceAllUsesWith(CoroBegin); + CF->eraseFromParent(); + } + + // Canonicalize coro.suspend by inserting a coro.save if needed. + for (CoroSuspendInst *CS : CoroSuspends) + if (!CS->getCoroSave()) + createCoroSave(CoroBegin, CS); + + // Move final suspend to be the last element in the CoroSuspends vector. + if (HasFinalSuspend && + FinalSuspendIndex != CoroSuspends.size() - 1) + std::swap(CoroSuspends[FinalSuspendIndex], CoroSuspends.back()); +} diff --git a/contrib/llvm/lib/Transforms/IPO/AlwaysInliner.cpp b/contrib/llvm/lib/Transforms/IPO/AlwaysInliner.cpp new file mode 100644 index 0000000..b7d9600 --- /dev/null +++ b/contrib/llvm/lib/Transforms/IPO/AlwaysInliner.cpp @@ -0,0 +1,158 @@ +//===- InlineAlways.cpp - Code to inline always_inline functions ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a custom inliner that handles only functions that +// are marked as "always inline". +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/AlwaysInliner.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/Inliner.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "inline" + +PreservedAnalyses AlwaysInlinerPass::run(Module &M, ModuleAnalysisManager &) { + InlineFunctionInfo IFI; + SmallSetVector<CallSite, 16> Calls; + bool Changed = false; + SmallVector<Function *, 16> InlinedFunctions; + for (Function &F : M) + if (!F.isDeclaration() && F.hasFnAttribute(Attribute::AlwaysInline) && + isInlineViable(F)) { + Calls.clear(); + + for (User *U : F.users()) + if (auto CS = CallSite(U)) + if (CS.getCalledFunction() == &F) + Calls.insert(CS); + + for (CallSite CS : Calls) + // FIXME: We really shouldn't be able to fail to inline at this point! + // We should do something to log or check the inline failures here. + Changed |= InlineFunction(CS, IFI); + + // Remember to try and delete this function afterward. This both avoids + // re-walking the rest of the module and avoids dealing with any iterator + // invalidation issues while deleting functions. + InlinedFunctions.push_back(&F); + } + + // Remove any live functions. 
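
This function is the new-pass-manager entry point for the always-inliner; a minimal sketch of scheduling it from a tool, using only standard new-PM boilerplate (nothing here is taken from the patch itself):

    #include "llvm/IR/Module.h"
    #include "llvm/IR/PassManager.h"
    #include "llvm/Passes/PassBuilder.h"
    #include "llvm/Transforms/IPO/AlwaysInliner.h"

    static void runAlwaysInliner(llvm::Module &M) {
      llvm::PassBuilder PB;
      llvm::ModuleAnalysisManager MAM;
      PB.registerModuleAnalyses(MAM); // harmless even though this pass queries none
      llvm::ModulePassManager MPM;
      MPM.addPass(llvm::AlwaysInlinerPass());
      MPM.run(M, MAM);
    }
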
+ erase_if(InlinedFunctions, [&](Function *F) { + F->removeDeadConstantUsers(); + return !F->isDefTriviallyDead(); + }); + + // Delete the non-comdat ones from the module and also from our vector. + auto NonComdatBegin = partition( + InlinedFunctions, [&](Function *F) { return F->hasComdat(); }); + for (Function *F : make_range(NonComdatBegin, InlinedFunctions.end())) + M.getFunctionList().erase(F); + InlinedFunctions.erase(NonComdatBegin, InlinedFunctions.end()); + + if (!InlinedFunctions.empty()) { + // Now we just have the comdat functions. Filter out the ones whose comdats + // are not actually dead. + filterDeadComdatFunctions(M, InlinedFunctions); + // The remaining functions are actually dead. + for (Function *F : InlinedFunctions) + M.getFunctionList().erase(F); + } + + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} + +namespace { + +/// Inliner pass which only handles "always inline" functions. +/// +/// Unlike the \c AlwaysInlinerPass, this uses the more heavyweight \c Inliner +/// base class to provide several facilities such as array alloca merging. +class AlwaysInlinerLegacyPass : public LegacyInlinerBase { + +public: + AlwaysInlinerLegacyPass() : LegacyInlinerBase(ID, /*InsertLifetime*/ true) { + initializeAlwaysInlinerLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + AlwaysInlinerLegacyPass(bool InsertLifetime) + : LegacyInlinerBase(ID, InsertLifetime) { + initializeAlwaysInlinerLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + /// Main run interface method. We override here to avoid calling skipSCC(). + bool runOnSCC(CallGraphSCC &SCC) override { return inlineCalls(SCC); } + + static char ID; // Pass identification, replacement for typeid + + InlineCost getInlineCost(CallSite CS) override; + + using llvm::Pass::doFinalization; + bool doFinalization(CallGraph &CG) override { + return removeDeadFunctions(CG, /*AlwaysInlineOnly=*/true); + } +}; +} + +char AlwaysInlinerLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(AlwaysInlinerLegacyPass, "always-inline", + "Inliner for always_inline functions", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(AlwaysInlinerLegacyPass, "always-inline", + "Inliner for always_inline functions", false, false) + +Pass *llvm::createAlwaysInlinerLegacyPass(bool InsertLifetime) { + return new AlwaysInlinerLegacyPass(InsertLifetime); +} + +/// \brief Get the inline cost for the always-inliner. +/// +/// The always inliner *only* handles functions which are marked with the +/// attribute to force inlining. As such, it is dramatically simpler and avoids +/// using the powerful (but expensive) inline cost analysis. Instead it uses +/// a very simple and boring direct walk of the instructions looking for +/// impossible-to-inline constructs. +/// +/// Note, it would be possible to go to some lengths to cache the information +/// computed here, but as we only expect to do this for relatively few and +/// small functions which have the explicit attribute to force inlining, it is +/// likely not worth it in practice. +InlineCost AlwaysInlinerLegacyPass::getInlineCost(CallSite CS) { + Function *Callee = CS.getCalledFunction(); + + // Only inline direct calls to functions with always-inline attributes + // that are viable for inlining. FIXME: We shouldn't even get here for + // declarations. 
+ if (Callee && !Callee->isDeclaration() && + CS.hasFnAttr(Attribute::AlwaysInline) && isInlineViable(*Callee)) + return InlineCost::getAlways(); + + return InlineCost::getNever(); +} diff --git a/contrib/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp b/contrib/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp index 0716a3a..65b7bad 100644 --- a/contrib/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/contrib/llvm/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -40,7 +40,6 @@ #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" @@ -177,8 +176,7 @@ static bool isDenselyPacked(Type *type, const DataLayout &DL) { // For homogenous sequential types, check for padding within members. if (SequentialType *seqTy = dyn_cast<SequentialType>(type)) - return isa<PointerType>(seqTy) || - isDenselyPacked(seqTy->getElementType(), DL); + return isDenselyPacked(seqTy->getElementType(), DL); // Check for padding within and between elements of a struct. StructType *StructTy = cast<StructType>(type); @@ -375,8 +373,8 @@ static bool AllCallersPassInValidPointerForArgument(Argument *Arg) { unsigned ArgNo = Arg->getArgNo(); - // Look at all call sites of the function. At this pointer we know we only - // have direct callees. + // Look at all call sites of the function. At this point we know we only have + // direct callees. for (User *U : Callee->users()) { CallSite CS(U); assert(CS && "Should only have direct calls!"); @@ -600,7 +598,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, // Because there could be several/many load instructions, remember which // blocks we know to be transparent to the load. - SmallPtrSet<BasicBlock*, 16> TranspBlocks; + df_iterator_default_set<BasicBlock*, 16> TranspBlocks; for (LoadInst *Load : Loads) { // Check to see if the load is invalidated from the start of the block to @@ -836,7 +834,10 @@ DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, Type::getInt64Ty(F->getContext())); Ops.push_back(ConstantInt::get(IdxTy, II)); // Keep track of the type we're currently indexing. - ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II); + if (auto *ElPTy = dyn_cast<PointerType>(ElTy)) + ElTy = ElPTy->getElementType(); + else + ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(II); } // And create a GEP to extract those indices. 
V = GetElementPtrInst::Create(ArgIndex.first, V, Ops, @@ -886,8 +887,8 @@ DoPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); cast<CallInst>(New)->setAttributes(AttributeSet::get(New->getContext(), AttributesVec)); - if (cast<CallInst>(Call)->isTailCall()) - cast<CallInst>(New)->setTailCall(); + cast<CallInst>(New)->setTailCallKind( + cast<CallInst>(Call)->getTailCallKind()); } New->setDebugLoc(Call->getDebugLoc()); Args.clear(); diff --git a/contrib/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp b/contrib/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp index 58731ea..ba2e60d 100644 --- a/contrib/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/contrib/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -155,7 +155,7 @@ bool CrossDSOCFI::runOnModule(Module &M) { return true; } -PreservedAnalyses CrossDSOCFIPass::run(Module &M, AnalysisManager<Module> &AM) { +PreservedAnalyses CrossDSOCFIPass::run(Module &M, ModuleAnalysisManager &AM) { CrossDSOCFI Impl; bool Changed = Impl.runOnModule(M); if (!Changed) diff --git a/contrib/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp b/contrib/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp index c8c895b..1a5ed46 100644 --- a/contrib/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/contrib/llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -190,8 +190,8 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { New = CallInst::Create(NF, Args, OpBundles, "", Call); cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); cast<CallInst>(New)->setAttributes(PAL); - if (cast<CallInst>(Call)->isTailCall()) - cast<CallInst>(New)->setTailCall(); + cast<CallInst>(New)->setTailCallKind( + cast<CallInst>(Call)->getTailCallKind()); } New->setDebugLoc(Call->getDebugLoc()); @@ -270,7 +270,7 @@ bool DeadArgumentEliminationPass::RemoveDeadArgumentsFromCallers(Function &Fn) { SmallVector<unsigned, 8> UnusedArgs; for (Argument &Arg : Fn.args()) { - if (Arg.use_empty() && !Arg.hasByValOrInAllocaAttr()) + if (!Arg.hasSwiftErrorAttr() && Arg.use_empty() && !Arg.hasByValOrInAllocaAttr()) UnusedArgs.push_back(Arg.getArgNo()); } @@ -896,8 +896,8 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { New = CallInst::Create(NF, Args, OpBundles, "", Call); cast<CallInst>(New)->setCallingConv(CS.getCallingConv()); cast<CallInst>(New)->setAttributes(NewCallPAL); - if (cast<CallInst>(Call)->isTailCall()) - cast<CallInst>(New)->setTailCall(); + cast<CallInst>(New)->setTailCallKind( + cast<CallInst>(Call)->getTailCallKind()); } New->setDebugLoc(Call->getDebugLoc()); diff --git a/contrib/llvm/lib/Transforms/IPO/FunctionAttrs.cpp b/contrib/llvm/lib/Transforms/IPO/FunctionAttrs.cpp index 787f434..402a665 100644 --- a/contrib/llvm/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/contrib/llvm/lib/Transforms/IPO/FunctionAttrs.cpp @@ -42,6 +42,7 @@ using namespace llvm; STATISTIC(NumReadNone, "Number of functions marked readnone"); STATISTIC(NumReadOnly, "Number of functions marked readonly"); STATISTIC(NumNoCapture, "Number of arguments marked nocapture"); +STATISTIC(NumReturned, "Number of arguments marked returned"); STATISTIC(NumReadNoneArg, "Number of arguments marked readnone"); STATISTIC(NumReadOnlyArg, "Number of arguments marked readonly"); STATISTIC(NumNoAlias, "Number of function returns marked noalias"); @@ -331,23 +332,16 @@ struct ArgumentUsesTracker : public CaptureTracker { namespace llvm { template <> struct GraphTraits<ArgumentGraphNode *> { - typedef ArgumentGraphNode 
NodeType; typedef ArgumentGraphNode *NodeRef; typedef SmallVectorImpl<ArgumentGraphNode *>::iterator ChildIteratorType; - static inline NodeType *getEntryNode(NodeType *A) { return A; } - static inline ChildIteratorType child_begin(NodeType *N) { - return N->Uses.begin(); - } - static inline ChildIteratorType child_end(NodeType *N) { - return N->Uses.end(); - } + static NodeRef getEntryNode(NodeRef A) { return A; } + static ChildIteratorType child_begin(NodeRef N) { return N->Uses.begin(); } + static ChildIteratorType child_end(NodeRef N) { return N->Uses.end(); } }; template <> struct GraphTraits<ArgumentGraph *> : public GraphTraits<ArgumentGraphNode *> { - static NodeType *getEntryNode(ArgumentGraph *AG) { - return AG->getEntryNode(); - } + static NodeRef getEntryNode(ArgumentGraph *AG) { return AG->getEntryNode(); } static ChildIteratorType nodes_begin(ArgumentGraph *AG) { return AG->begin(); } @@ -447,8 +441,8 @@ determinePointerReadAttrs(Argument *A, // to a operand bundle use, these cannot participate in the optimistic SCC // analysis. Instead, we model the operand bundle uses as arguments in // call to a function external to the SCC. - if (!SCCNodes.count(&*std::next(F->arg_begin(), UseIndex)) || - IsOperandBundleUse) { + if (IsOperandBundleUse || + !SCCNodes.count(&*std::next(F->arg_begin(), UseIndex))) { // The accessors used on CallSite here do the right thing for calls and // invokes with operand bundles. @@ -484,6 +478,59 @@ determinePointerReadAttrs(Argument *A, return IsRead ? Attribute::ReadOnly : Attribute::ReadNone; } +/// Deduce returned attributes for the SCC. +static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) { + bool Changed = false; + + AttrBuilder B; + B.addAttribute(Attribute::Returned); + + // Check each function in turn, determining if an argument is always returned. + for (Function *F : SCCNodes) { + // We can infer and propagate function attributes only when we know that the + // definition we'll get at link time is *exactly* the definition we see now. + // For more details, see GlobalValue::mayBeDerefined. + if (!F->hasExactDefinition()) + continue; + + if (F->getReturnType()->isVoidTy()) + continue; + + // There is nothing to do if an argument is already marked as 'returned'. + if (any_of(F->args(), + [](const Argument &Arg) { return Arg.hasReturnedAttr(); })) + continue; + + auto FindRetArg = [&]() -> Value * { + Value *RetArg = nullptr; + for (BasicBlock &BB : *F) + if (auto *Ret = dyn_cast<ReturnInst>(BB.getTerminator())) { + // Note that stripPointerCasts should look through functions with + // returned arguments. + Value *RetVal = Ret->getReturnValue()->stripPointerCasts(); + if (!isa<Argument>(RetVal) || RetVal->getType() != F->getReturnType()) + return nullptr; + + if (!RetArg) + RetArg = RetVal; + else if (RetArg != RetVal) + return nullptr; + } + + return RetArg; + }; + + if (Value *RetArg = FindRetArg()) { + auto *A = cast<Argument>(RetArg); + A->addAttr(AttributeSet::get(F->getContext(), A->getArgNo() + 1, B)); + ++NumReturned; + Changed = true; + } + } + + return Changed; +} + /// Deduce nocapture attributes for the SCC. static bool addArgumentAttrs(const SCCNodeSet &SCCNodes) { bool Changed = false; @@ -726,7 +773,8 @@ static bool isFunctionMallocLike(Function *F, const SCCNodeSet &SCCNodes) { break; if (CS.getCalledFunction() && SCCNodes.count(CS.getCalledFunction())) break; - } // fall-through + LLVM_FALLTHROUGH; + } default: return false; // Did not come from an allocation. 
} @@ -986,9 +1034,11 @@ static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { } PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, - CGSCCAnalysisManager &AM) { + CGSCCAnalysisManager &AM, + LazyCallGraph &CG, + CGSCCUpdateResult &) { FunctionAnalysisManager &FAM = - AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C).getManager(); + AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager(); // We pass a lambda into functions to wire them up to the analysis manager // for getting function analyses. @@ -1025,6 +1075,7 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, } bool Changed = false; + Changed |= addArgumentReturnedAttrs(SCCNodes); Changed |= addReadAttrs(SCCNodes, AARGetter); Changed |= addArgumentAttrs(SCCNodes); @@ -1044,7 +1095,8 @@ namespace { struct PostOrderFunctionAttrsLegacyPass : public CallGraphSCCPass { static char ID; // Pass identification, replacement for typeid PostOrderFunctionAttrsLegacyPass() : CallGraphSCCPass(ID) { - initializePostOrderFunctionAttrsLegacyPassPass(*PassRegistry::getPassRegistry()); + initializePostOrderFunctionAttrsLegacyPassPass( + *PassRegistry::getPassRegistry()); } bool runOnSCC(CallGraphSCC &SCC) override; @@ -1066,7 +1118,9 @@ INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) INITIALIZE_PASS_END(PostOrderFunctionAttrsLegacyPass, "functionattrs", "Deduce function attributes", false, false) -Pass *llvm::createPostOrderFunctionAttrsLegacyPass() { return new PostOrderFunctionAttrsLegacyPass(); } +Pass *llvm::createPostOrderFunctionAttrsLegacyPass() { + return new PostOrderFunctionAttrsLegacyPass(); +} template <typename AARGetterT> static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { @@ -1090,6 +1144,7 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { SCCNodes.insert(F); } + Changed |= addArgumentReturnedAttrs(SCCNodes); Changed |= addReadAttrs(SCCNodes, AARGetter); Changed |= addArgumentAttrs(SCCNodes); @@ -1127,7 +1182,8 @@ namespace { struct ReversePostOrderFunctionAttrsLegacyPass : public ModulePass { static char ID; // Pass identification, replacement for typeid ReversePostOrderFunctionAttrsLegacyPass() : ModulePass(ID) { - initializeReversePostOrderFunctionAttrsLegacyPassPass(*PassRegistry::getPassRegistry()); + initializeReversePostOrderFunctionAttrsLegacyPassPass( + *PassRegistry::getPassRegistry()); } bool runOnModule(Module &M) override; @@ -1216,10 +1272,17 @@ bool ReversePostOrderFunctionAttrsLegacyPass::runOnModule(Module &M) { } PreservedAnalyses -ReversePostOrderFunctionAttrsPass::run(Module &M, AnalysisManager<Module> &AM) { +ReversePostOrderFunctionAttrsPass::run(Module &M, ModuleAnalysisManager &AM) { auto &CG = AM.getResult<CallGraphAnalysis>(M); bool Changed = deduceFunctionAttributeInRPO(M, CG); + + // CallGraphAnalysis holds AssertingVH and must be invalidated eagerly so + // that other passes don't delete stuff from under it. + // FIXME: We need to invalidate this to avoid PR28400. Is there a better + // solution? 
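
Looping back to the addArgumentReturnedAttrs deduction added earlier in this file: it fires when every return in a function with an exact definition hands back the same argument, modulo pointer casts. A small source-level example of a function whose first parameter would typically end up marked `returned` (illustrative only, not from the patch):

    #include <cstddef>

    // Every return path yields Buf, so the argument can carry 'returned'.
    char *terminate(char *Buf, std::size_t Len) {
      Buf[Len] = '\0';
      return Buf;
    }
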
+ AM.invalidate<CallGraphAnalysis>(M); + if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/contrib/llvm/lib/Transforms/IPO/FunctionImport.cpp b/contrib/llvm/lib/Transforms/IPO/FunctionImport.cpp index c9d075e..6b32f6c 100644 --- a/contrib/llvm/lib/Transforms/IPO/FunctionImport.cpp +++ b/contrib/llvm/lib/Transforms/IPO/FunctionImport.cpp @@ -21,6 +21,7 @@ #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Linker/Linker.h" #include "llvm/Object/IRObjectFile.h" @@ -35,7 +36,10 @@ using namespace llvm; -STATISTIC(NumImported, "Number of functions imported"); +STATISTIC(NumImportedFunctions, "Number of functions imported"); +STATISTIC(NumImportedModules, "Number of modules imported from"); +STATISTIC(NumDeadSymbols, "Number of dead stripped symbols in index"); +STATISTIC(NumLiveSymbols, "Number of live symbols in index"); /// Limit on instruction count of imported functions. static cl::opt<unsigned> ImportInstrLimit( @@ -49,9 +53,28 @@ static cl::opt<float> "`import-instr-limit` threshold by this factor " "before processing newly imported functions")); +static cl::opt<float> ImportHotInstrFactor( + "import-hot-evolution-factor", cl::init(1.0), cl::Hidden, + cl::value_desc("x"), + cl::desc("As we import functions called from hot callsite, multiply the " + "`import-instr-limit` threshold by this factor " + "before processing newly imported functions")); + +static cl::opt<float> ImportHotMultiplier( + "import-hot-multiplier", cl::init(3.0), cl::Hidden, cl::value_desc("x"), + cl::desc("Multiply the `import-instr-limit` threshold for hot callsites")); + +// FIXME: This multiplier was not really tuned up. +static cl::opt<float> ImportColdMultiplier( + "import-cold-multiplier", cl::init(0), cl::Hidden, cl::value_desc("N"), + cl::desc("Multiply the `import-instr-limit` threshold for cold callsites")); + static cl::opt<bool> PrintImports("print-imports", cl::init(false), cl::Hidden, cl::desc("Print imported functions")); +static cl::opt<bool> ComputeDead("compute-dead", cl::init(true), cl::Hidden, + cl::desc("Compute dead symbols")); + // Temporary allows the function import pass to disable always linking // referenced discardable symbols. static cl::opt<bool> @@ -88,69 +111,6 @@ static std::unique_ptr<Module> loadFile(const std::string &FileName, namespace { -// Return true if the Summary describes a GlobalValue that can be externally -// referenced, i.e. it does not need renaming (linkage is not local) or renaming -// is possible (does not have a section for instance). -static bool canBeExternallyReferenced(const GlobalValueSummary &Summary) { - if (!Summary.needsRenaming()) - return true; - - if (Summary.hasSection()) - // Can't rename a global that needs renaming if has a section. - return false; - - return true; -} - -// Return true if \p GUID describes a GlobalValue that can be externally -// referenced, i.e. it does not need renaming (linkage is not local) or -// renaming is possible (does not have a section for instance). -static bool canBeExternallyReferenced(const ModuleSummaryIndex &Index, - GlobalValue::GUID GUID) { - auto Summaries = Index.findGlobalValueSummaryList(GUID); - if (Summaries == Index.end()) - return true; - if (Summaries->second.size() != 1) - // If there are multiple globals with this GUID, then we know it is - // not a local symbol, and it is necessarily externally referenced. 
- return true; - - // We don't need to check for the module path, because if it can't be - // externally referenced and we call it, it is necessarilly in the same - // module - return canBeExternallyReferenced(**Summaries->second.begin()); -} - -// Return true if the global described by \p Summary can be imported in another -// module. -static bool eligibleForImport(const ModuleSummaryIndex &Index, - const GlobalValueSummary &Summary) { - if (!canBeExternallyReferenced(Summary)) - // Can't import a global that needs renaming if has a section for instance. - // FIXME: we may be able to import it by copying it without promotion. - return false; - - // Check references (and potential calls) in the same module. If the current - // value references a global that can't be externally referenced it is not - // eligible for import. - bool AllRefsCanBeExternallyReferenced = - llvm::all_of(Summary.refs(), [&](const ValueInfo &VI) { - return canBeExternallyReferenced(Index, VI.getGUID()); - }); - if (!AllRefsCanBeExternallyReferenced) - return false; - - if (auto *FuncSummary = dyn_cast<FunctionSummary>(&Summary)) { - bool AllCallsCanBeExternallyReferenced = llvm::all_of( - FuncSummary->calls(), [&](const FunctionSummary::EdgeTy &Edge) { - return canBeExternallyReferenced(Index, Edge.first.getGUID()); - }); - if (!AllCallsCanBeExternallyReferenced) - return false; - } - return true; -} - /// Given a list of possible callee implementation for a call site, select one /// that fits the \p Threshold. /// @@ -188,7 +148,7 @@ selectCallee(const ModuleSummaryIndex &Index, if (Summary->instCount() > Threshold) return false; - if (!eligibleForImport(Index, *Summary)) + if (Summary->notEligibleToImport()) return false; return true; @@ -210,63 +170,17 @@ static const GlobalValueSummary *selectCallee(GlobalValue::GUID GUID, return selectCallee(Index, CalleeSummaryList->second, Threshold); } -/// Mark the global \p GUID as export by module \p ExportModulePath if found in -/// this module. If it is a GlobalVariable, we also mark any referenced global -/// in the current module as exported. -static void exportGlobalInModule(const ModuleSummaryIndex &Index, - StringRef ExportModulePath, - GlobalValue::GUID GUID, - FunctionImporter::ExportSetTy &ExportList) { - auto FindGlobalSummaryInModule = - [&](GlobalValue::GUID GUID) -> GlobalValueSummary *{ - auto SummaryList = Index.findGlobalValueSummaryList(GUID); - if (SummaryList == Index.end()) - // This global does not have a summary, it is not part of the ThinLTO - // process - return nullptr; - auto SummaryIter = llvm::find_if( - SummaryList->second, - [&](const std::unique_ptr<GlobalValueSummary> &Summary) { - return Summary->modulePath() == ExportModulePath; - }); - if (SummaryIter == SummaryList->second.end()) - return nullptr; - return SummaryIter->get(); - }; - - auto *Summary = FindGlobalSummaryInModule(GUID); - if (!Summary) - return; - // We found it in the current module, mark as exported - ExportList.insert(GUID); - - auto GVS = dyn_cast<GlobalVarSummary>(Summary); - if (!GVS) - return; - // FunctionImportGlobalProcessing::doPromoteLocalToGlobal() will always - // trigger importing the initializer for `constant unnamed addr` globals that - // are referenced. We conservatively export all the referenced symbols for - // every global to workaround this, so that the ExportList is accurate. - // FIXME: with a "isConstant" flag in the summary we could be more targetted. 
- for (auto &Ref : GVS->refs()) { - auto GUID = Ref.getGUID(); - auto *RefSummary = FindGlobalSummaryInModule(GUID); - if (RefSummary) - // Found a ref in the current module, mark it as exported - ExportList.insert(GUID); - } -} - -using EdgeInfo = std::pair<const FunctionSummary *, unsigned /* Threshold */>; +using EdgeInfo = std::tuple<const FunctionSummary *, unsigned /* Threshold */, + GlobalValue::GUID>; /// Compute the list of functions to import for a given caller. Mark these /// imported functions and the symbols they reference in their source module as /// exported from their source module. static void computeImportForFunction( const FunctionSummary &Summary, const ModuleSummaryIndex &Index, - unsigned Threshold, const GVSummaryMapTy &DefinedGVSummaries, + const unsigned Threshold, const GVSummaryMapTy &DefinedGVSummaries, SmallVectorImpl<EdgeInfo> &Worklist, - FunctionImporter::ImportMapTy &ImportsForModule, + FunctionImporter::ImportMapTy &ImportList, StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { for (auto &Edge : Summary.calls()) { auto GUID = Edge.first.getGUID(); @@ -277,7 +191,18 @@ static void computeImportForFunction( continue; } - auto *CalleeSummary = selectCallee(GUID, Threshold, Index); + auto GetBonusMultiplier = [](CalleeInfo::HotnessType Hotness) -> float { + if (Hotness == CalleeInfo::HotnessType::Hot) + return ImportHotMultiplier; + if (Hotness == CalleeInfo::HotnessType::Cold) + return ImportColdMultiplier; + return 1.0; + }; + + const auto NewThreshold = + Threshold * GetBonusMultiplier(Edge.second.Hotness); + + auto *CalleeSummary = selectCallee(GUID, NewThreshold, Index); if (!CalleeSummary) { DEBUG(dbgs() << "ignored! No qualifying callee with summary found.\n"); continue; @@ -293,40 +218,59 @@ static void computeImportForFunction( } else ResolvedCalleeSummary = cast<FunctionSummary>(CalleeSummary); - assert(ResolvedCalleeSummary->instCount() <= Threshold && + assert(ResolvedCalleeSummary->instCount() <= NewThreshold && "selectCallee() didn't honor the threshold"); + auto GetAdjustedThreshold = [](unsigned Threshold, bool IsHotCallsite) { + // Adjust the threshold for next level of imported functions. + // The threshold is different for hot callsites because we can then + // inline chains of hot calls. + if (IsHotCallsite) + return Threshold * ImportHotInstrFactor; + return Threshold * ImportInstrFactor; + }; + + bool IsHotCallsite = Edge.second.Hotness == CalleeInfo::HotnessType::Hot; + const auto AdjThreshold = GetAdjustedThreshold(Threshold, IsHotCallsite); + auto ExportModulePath = ResolvedCalleeSummary->modulePath(); - auto &ProcessedThreshold = ImportsForModule[ExportModulePath][GUID]; + auto &ProcessedThreshold = ImportList[ExportModulePath][GUID]; /// Since the traversal of the call graph is DFS, we can revisit a function /// a second time with a higher threshold. In this case, it is added back to /// the worklist with the new threshold. - if (ProcessedThreshold && ProcessedThreshold >= Threshold) { + if (ProcessedThreshold && ProcessedThreshold >= AdjThreshold) { DEBUG(dbgs() << "ignored! Target was already seen with Threshold " << ProcessedThreshold << "\n"); continue; } + bool PreviouslyImported = ProcessedThreshold != 0; // Mark this function as imported in this module, with the current Threshold - ProcessedThreshold = Threshold; + ProcessedThreshold = AdjThreshold; // Make exports in the source module. 
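The hunk above introduces two knobs: a hotness bonus that scales the threshold handed to selectCallee(), and an evolution factor that sets the threshold recorded for (and propagated to) the callee's own callees. A minimal standalone sketch of that arithmetic, assuming an illustrative base limit of 100 and an illustrative evolution factor of 0.7 (only the hot/cold multipliers and the hot evolution factor of 1.0 are defaults visible in this patch; the helper names below are not LLVM's):

#include <cstdio>

enum class Hotness { None, Hot, Cold };

static float bonusMultiplier(Hotness H) {
  if (H == Hotness::Hot)
    return 3.0f; // -import-hot-multiplier default shown in this patch
  if (H == Hotness::Cold)
    return 0.0f; // -import-cold-multiplier default shown in this patch
  return 1.0f;
}

static unsigned nextLevelThreshold(unsigned Threshold, bool IsHotCallsite,
                                   float InstrFactor, float HotInstrFactor) {
  // Hot call chains evolve with the hot factor; everything else scales with
  // the regular -import-instr-evolution-factor.
  return static_cast<unsigned>(
      Threshold * (IsHotCallsite ? HotInstrFactor : InstrFactor));
}

int main() {
  const unsigned Base = 100; // assumed -import-instr-limit, for illustration
  unsigned SelectionThreshold =
      static_cast<unsigned>(Base * bonusMultiplier(Hotness::Hot)); // 300
  unsigned Recorded = nextLevelThreshold(Base, /*IsHotCallsite=*/true,
                                         /*InstrFactor=*/0.7f,
                                         /*HotInstrFactor=*/1.0f); // 100
  std::printf("selectCallee sees %u, worklist records %u\n",
              SelectionThreshold, Recorded);
  return 0;
}

Note that, as in the hunk, the bonus multiplier only affects callee selection; the threshold recorded in the import map and pushed on the worklist is derived from the unscaled caller threshold.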
if (ExportLists) { auto &ExportList = (*ExportLists)[ExportModulePath]; ExportList.insert(GUID); - // Mark all functions and globals referenced by this function as exported - // to the outside if they are defined in the same source module. - for (auto &Edge : ResolvedCalleeSummary->calls()) { - auto CalleeGUID = Edge.first.getGUID(); - exportGlobalInModule(Index, ExportModulePath, CalleeGUID, ExportList); - } - for (auto &Ref : ResolvedCalleeSummary->refs()) { - auto GUID = Ref.getGUID(); - exportGlobalInModule(Index, ExportModulePath, GUID, ExportList); + if (!PreviouslyImported) { + // This is the first time this function was exported from its source + // module, so mark all functions and globals it references as exported + // to the outside if they are defined in the same source module. + // For efficiency, we unconditionally add all the referenced GUIDs + // to the ExportList for this module, and will prune out any not + // defined in the module later in a single pass. + for (auto &Edge : ResolvedCalleeSummary->calls()) { + auto CalleeGUID = Edge.first.getGUID(); + ExportList.insert(CalleeGUID); + } + for (auto &Ref : ResolvedCalleeSummary->refs()) { + auto GUID = Ref.getGUID(); + ExportList.insert(GUID); + } } } // Insert the newly imported function to the worklist. - Worklist.push_back(std::make_pair(ResolvedCalleeSummary, Threshold)); + Worklist.emplace_back(ResolvedCalleeSummary, AdjThreshold, GUID); } } @@ -335,8 +279,9 @@ static void computeImportForFunction( /// another module (that may require promotion). static void ComputeImportForModule( const GVSummaryMapTy &DefinedGVSummaries, const ModuleSummaryIndex &Index, - FunctionImporter::ImportMapTy &ImportsForModule, - StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { + FunctionImporter::ImportMapTy &ImportList, + StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr, + const DenseSet<GlobalValue::GUID> *DeadSymbols = nullptr) { // Worklist contains the list of function imported in this module, for which // we will analyse the callees and may import further down the callgraph. SmallVector<EdgeInfo, 128> Worklist; @@ -344,6 +289,10 @@ static void ComputeImportForModule( // Populate the worklist with the import for the functions in the current // module for (auto &GVSummary : DefinedGVSummaries) { + if (DeadSymbols && DeadSymbols->count(GVSummary.first)) { + DEBUG(dbgs() << "Ignores Dead GUID: " << GVSummary.first << "\n"); + continue; + } auto *Summary = GVSummary.second; if (auto *AS = dyn_cast<AliasSummary>(Summary)) Summary = &AS->getAliasee(); @@ -353,21 +302,26 @@ static void ComputeImportForModule( continue; DEBUG(dbgs() << "Initalize import for " << GVSummary.first << "\n"); computeImportForFunction(*FuncSummary, Index, ImportInstrLimit, - DefinedGVSummaries, Worklist, ImportsForModule, + DefinedGVSummaries, Worklist, ImportList, ExportLists); } + // Process the newly imported functions and add callees to the worklist. while (!Worklist.empty()) { auto FuncInfo = Worklist.pop_back_val(); - auto *Summary = FuncInfo.first; - auto Threshold = FuncInfo.second; - - // Process the newly imported functions and add callees to the worklist. - // Adjust the threshold - Threshold = Threshold * ImportInstrFactor; + auto *Summary = std::get<0>(FuncInfo); + auto Threshold = std::get<1>(FuncInfo); + auto GUID = std::get<2>(FuncInfo); + + // Check if we later added this summary with a higher threshold. + // If so, skip this entry. 
+ auto ExportModulePath = Summary->modulePath(); + auto &LatestProcessedThreshold = ImportList[ExportModulePath][GUID]; + if (LatestProcessedThreshold > Threshold) + continue; computeImportForFunction(*Summary, Index, Threshold, DefinedGVSummaries, - Worklist, ImportsForModule, ExportLists); + Worklist, ImportList, ExportLists); } } @@ -378,14 +332,31 @@ void llvm::ComputeCrossModuleImport( const ModuleSummaryIndex &Index, const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, StringMap<FunctionImporter::ImportMapTy> &ImportLists, - StringMap<FunctionImporter::ExportSetTy> &ExportLists) { + StringMap<FunctionImporter::ExportSetTy> &ExportLists, + const DenseSet<GlobalValue::GUID> *DeadSymbols) { // For each module that has function defined, compute the import/export lists. for (auto &DefinedGVSummaries : ModuleToDefinedGVSummaries) { - auto &ImportsForModule = ImportLists[DefinedGVSummaries.first()]; + auto &ImportList = ImportLists[DefinedGVSummaries.first()]; DEBUG(dbgs() << "Computing import for Module '" << DefinedGVSummaries.first() << "'\n"); - ComputeImportForModule(DefinedGVSummaries.second, Index, ImportsForModule, - &ExportLists); + ComputeImportForModule(DefinedGVSummaries.second, Index, ImportList, + &ExportLists, DeadSymbols); + } + + // When computing imports we added all GUIDs referenced by anything + // imported from the module to its ExportList. Now we prune each ExportList + // of any not defined in that module. This is more efficient than checking + // while computing imports because some of the summary lists may be long + // due to linkonce (comdat) copies. + for (auto &ELI : ExportLists) { + const auto &DefinedGVSummaries = + ModuleToDefinedGVSummaries.lookup(ELI.first()); + for (auto EI = ELI.second.begin(); EI != ELI.second.end();) { + if (!DefinedGVSummaries.count(*EI)) + EI = ELI.second.erase(EI); + else + ++EI; + } } #ifndef NDEBUG @@ -431,45 +402,120 @@ void llvm::ComputeCrossModuleImportForModule( #endif } +DenseSet<GlobalValue::GUID> llvm::computeDeadSymbols( + const ModuleSummaryIndex &Index, + const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols) { + if (!ComputeDead) + return DenseSet<GlobalValue::GUID>(); + if (GUIDPreservedSymbols.empty()) + // Don't do anything when nothing is live, this is friendly with tests. + return DenseSet<GlobalValue::GUID>(); + DenseSet<GlobalValue::GUID> LiveSymbols = GUIDPreservedSymbols; + SmallVector<GlobalValue::GUID, 128> Worklist; + Worklist.reserve(LiveSymbols.size() * 2); + for (auto GUID : LiveSymbols) { + DEBUG(dbgs() << "Live root: " << GUID << "\n"); + Worklist.push_back(GUID); + } + // Add values flagged in the index as live roots to the worklist. 
+ for (const auto &Entry : Index) { + bool IsLiveRoot = llvm::any_of( + Entry.second, + [&](const std::unique_ptr<llvm::GlobalValueSummary> &Summary) { + return Summary->liveRoot(); + }); + if (!IsLiveRoot) + continue; + DEBUG(dbgs() << "Live root (summary): " << Entry.first << "\n"); + Worklist.push_back(Entry.first); + } + + while (!Worklist.empty()) { + auto GUID = Worklist.pop_back_val(); + auto It = Index.findGlobalValueSummaryList(GUID); + if (It == Index.end()) { + DEBUG(dbgs() << "Not in index: " << GUID << "\n"); + continue; + } + + // FIXME: we should only make the prevailing copy live here + for (auto &Summary : It->second) { + for (auto Ref : Summary->refs()) { + auto RefGUID = Ref.getGUID(); + if (LiveSymbols.insert(RefGUID).second) { + DEBUG(dbgs() << "Marking live (ref): " << RefGUID << "\n"); + Worklist.push_back(RefGUID); + } + } + if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) { + for (auto Call : FS->calls()) { + auto CallGUID = Call.first.getGUID(); + if (LiveSymbols.insert(CallGUID).second) { + DEBUG(dbgs() << "Marking live (call): " << CallGUID << "\n"); + Worklist.push_back(CallGUID); + } + } + } + if (auto *AS = dyn_cast<AliasSummary>(Summary.get())) { + auto AliaseeGUID = AS->getAliasee().getOriginalName(); + if (LiveSymbols.insert(AliaseeGUID).second) { + DEBUG(dbgs() << "Marking live (alias): " << AliaseeGUID << "\n"); + Worklist.push_back(AliaseeGUID); + } + } + } + } + DenseSet<GlobalValue::GUID> DeadSymbols; + DeadSymbols.reserve( + std::min(Index.size(), Index.size() - LiveSymbols.size())); + for (auto &Entry : Index) { + auto GUID = Entry.first; + if (!LiveSymbols.count(GUID)) { + DEBUG(dbgs() << "Marking dead: " << GUID << "\n"); + DeadSymbols.insert(GUID); + } + } + DEBUG(dbgs() << LiveSymbols.size() << " symbols Live, and " + << DeadSymbols.size() << " symbols Dead \n"); + NumDeadSymbols += DeadSymbols.size(); + NumLiveSymbols += LiveSymbols.size(); + return DeadSymbols; +} + /// Compute the set of summaries needed for a ThinLTO backend compilation of /// \p ModulePath. void llvm::gatherImportedSummariesForModule( StringRef ModulePath, const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries, - const StringMap<FunctionImporter::ImportMapTy> &ImportLists, + const FunctionImporter::ImportMapTy &ImportList, std::map<std::string, GVSummaryMapTy> &ModuleToSummariesForIndex) { // Include all summaries from the importing module. ModuleToSummariesForIndex[ModulePath] = ModuleToDefinedGVSummaries.lookup(ModulePath); - auto ModuleImports = ImportLists.find(ModulePath); - if (ModuleImports != ImportLists.end()) { - // Include summaries for imports. - for (auto &ILI : ModuleImports->second) { - auto &SummariesForIndex = ModuleToSummariesForIndex[ILI.first()]; - const auto &DefinedGVSummaries = - ModuleToDefinedGVSummaries.lookup(ILI.first()); - for (auto &GI : ILI.second) { - const auto &DS = DefinedGVSummaries.find(GI.first); - assert(DS != DefinedGVSummaries.end() && - "Expected a defined summary for imported global value"); - SummariesForIndex[GI.first] = DS->second; - } + // Include summaries for imports. 
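Taken together with the previous hunks, a ThinLTO driver can now compute the dead-symbol set once and thread it through import computation. A hedged sketch of that wiring, using only the entry points whose signatures appear in this patch; the helper name is hypothetical, the headers are the usual ones and are assumed, and building the combined index and the per-module summary map is elided:

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/IR/ModuleSummaryIndex.h"
#include "llvm/Transforms/IPO/FunctionImport.h"
using namespace llvm;

static void computeImports(
    const ModuleSummaryIndex &Index,
    const StringMap<GVSummaryMapTy> &ModuleToDefinedGVSummaries,
    const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols,
    StringMap<FunctionImporter::ImportMapTy> &ImportLists,
    StringMap<FunctionImporter::ExportSetTy> &ExportLists) {
  // Mark everything reachable from the preserved roots as live, once.
  DenseSet<GlobalValue::GUID> DeadSymbols =
      computeDeadSymbols(Index, GUIDPreservedSymbols);
  // Let the per-module import computation skip anything in the dead set.
  ComputeCrossModuleImport(Index, ModuleToDefinedGVSummaries, ImportLists,
                           ExportLists, &DeadSymbols);
}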
+ for (auto &ILI : ImportList) { + auto &SummariesForIndex = ModuleToSummariesForIndex[ILI.first()]; + const auto &DefinedGVSummaries = + ModuleToDefinedGVSummaries.lookup(ILI.first()); + for (auto &GI : ILI.second) { + const auto &DS = DefinedGVSummaries.find(GI.first); + assert(DS != DefinedGVSummaries.end() && + "Expected a defined summary for imported global value"); + SummariesForIndex[GI.first] = DS->second; } } } /// Emit the files \p ModulePath will import from into \p OutputFilename. -std::error_code llvm::EmitImportsFiles( - StringRef ModulePath, StringRef OutputFilename, - const StringMap<FunctionImporter::ImportMapTy> &ImportLists) { - auto ModuleImports = ImportLists.find(ModulePath); +std::error_code +llvm::EmitImportsFiles(StringRef ModulePath, StringRef OutputFilename, + const FunctionImporter::ImportMapTy &ModuleImports) { std::error_code EC; raw_fd_ostream ImportsOS(OutputFilename, EC, sys::fs::OpenFlags::F_None); if (EC) return EC; - if (ModuleImports != ImportLists.end()) - for (auto &ILI : ModuleImports->second) - ImportsOS << ILI.first() << "\n"; + for (auto &ILI : ModuleImports) + ImportsOS << ILI.first() << "\n"; return std::error_code(); } @@ -489,6 +535,15 @@ void llvm::thinLTOResolveWeakForLinkerModule( DEBUG(dbgs() << "ODR fixing up linkage for `" << GV.getName() << "` from " << GV.getLinkage() << " to " << NewLinkage << "\n"); GV.setLinkage(NewLinkage); + // Remove functions converted to available_externally from comdats, + // as this is a declaration for the linker, and will be dropped eventually. + // It is illegal for comdats to contain declarations. + auto *GO = dyn_cast_or_null<GlobalObject>(&GV); + if (GO && GO->isDeclarationForLinker() && GO->hasComdat()) { + assert(GO->hasAvailableExternallyLinkage() && + "Expected comdat on definition (possibly available external)"); + GO->setComdat(nullptr); + } }; // Process functions and global now @@ -506,7 +561,7 @@ void llvm::thinLTOInternalizeModule(Module &TheModule, // Parse inline ASM and collect the list of symbols that are not defined in // the current module. StringSet<> AsmUndefinedRefs; - object::IRObjectFile::CollectAsmUndefinedRefs( + ModuleSymbolTable::CollectAsmSymbols( Triple(TheModule.getTargetTriple()), TheModule.getModuleInlineAsm(), [&AsmUndefinedRefs](StringRef Name, object::BasicSymbolRef::Flags Flags) { if (Flags & object::BasicSymbolRef::SF_Undefined) @@ -561,7 +616,7 @@ void llvm::thinLTOInternalizeModule(Module &TheModule, // Automatically import functions in Module \p DestModule based on the summaries // index. // -bool FunctionImporter::importFunctions( +Expected<bool> FunctionImporter::importFunctions( Module &DestModule, const FunctionImporter::ImportMapTy &ImportList, bool ForceImportReferencedDiscardableSymbols) { DEBUG(dbgs() << "Starting import for Module " @@ -579,14 +634,17 @@ bool FunctionImporter::importFunctions( // Get the module for the import const auto &FunctionsToImportPerModule = ImportList.find(Name); assert(FunctionsToImportPerModule != ImportList.end()); - std::unique_ptr<Module> SrcModule = ModuleLoader(Name); + Expected<std::unique_ptr<Module>> SrcModuleOrErr = ModuleLoader(Name); + if (!SrcModuleOrErr) + return SrcModuleOrErr.takeError(); + std::unique_ptr<Module> SrcModule = std::move(*SrcModuleOrErr); assert(&DestModule.getContext() == &SrcModule->getContext() && "Context mismatch"); // If modules were created with lazy metadata loading, materialize it // now, before linking it (otherwise this will be a noop). 
- SrcModule->materializeMetadata(); - UpgradeDebugInfo(*SrcModule); + if (Error Err = SrcModule->materializeMetadata()) + return std::move(Err); auto &ImportGUIDs = FunctionsToImportPerModule->second; // Find the globals to import @@ -600,7 +658,8 @@ bool FunctionImporter::importFunctions( << " " << F.getName() << " from " << SrcModule->getSourceFileName() << "\n"); if (Import) { - F.materialize(); + if (Error Err = F.materialize()) + return std::move(Err); if (EnableImportMetadata) { // Add 'thinlto_src_module' metadata for statistics and debugging. F.setMetadata( @@ -622,7 +681,8 @@ bool FunctionImporter::importFunctions( << " " << GV.getName() << " from " << SrcModule->getSourceFileName() << "\n"); if (Import) { - GV.materialize(); + if (Error Err = GV.materialize()) + return std::move(Err); GlobalsToImport.insert(&GV); } } @@ -648,13 +708,19 @@ bool FunctionImporter::importFunctions( << " " << GO->getName() << " from " << SrcModule->getSourceFileName() << "\n"); #endif - GO->materialize(); + if (Error Err = GO->materialize()) + return std::move(Err); GlobalsToImport.insert(GO); - GA.materialize(); + if (Error Err = GA.materialize()) + return std::move(Err); GlobalsToImport.insert(&GA); } } + // Upgrade debug info after we're done materializing all the globals and we + // have loaded all the required metadata! + UpgradeDebugInfo(*SrcModule); + // Link in the specified functions. if (renameModuleForThinLTO(*SrcModule, Index, &GlobalsToImport)) return true; @@ -674,9 +740,10 @@ bool FunctionImporter::importFunctions( report_fatal_error("Function Import: link error"); ImportedCount += GlobalsToImport.size(); + NumImportedModules++; } - NumImported += ImportedCount; + NumImportedFunctions += ImportedCount; DEBUG(dbgs() << "Imported " << ImportedCount << " functions for Module " << DestModule.getModuleIdentifier() << "\n"); @@ -689,106 +756,94 @@ static cl::opt<std::string> SummaryFile("summary-file", cl::desc("The summary file to use for function importing.")); -static void diagnosticHandler(const DiagnosticInfo &DI) { - raw_ostream &OS = errs(); - DiagnosticPrinterRawOStream DP(OS); - DI.print(DP); - OS << '\n'; -} +static bool doImportingForModule(Module &M) { + if (SummaryFile.empty()) + report_fatal_error("error: -function-import requires -summary-file\n"); + Expected<std::unique_ptr<ModuleSummaryIndex>> IndexPtrOrErr = + getModuleSummaryIndexForFile(SummaryFile); + if (!IndexPtrOrErr) { + logAllUnhandledErrors(IndexPtrOrErr.takeError(), errs(), + "Error loading file '" + SummaryFile + "': "); + return false; + } + std::unique_ptr<ModuleSummaryIndex> Index = std::move(*IndexPtrOrErr); + + // First step is collecting the import list. + FunctionImporter::ImportMapTy ImportList; + ComputeCrossModuleImportForModule(M.getModuleIdentifier(), *Index, + ImportList); + + // Conservatively mark all internal values as promoted. This interface is + // only used when doing importing via the function importing pass. The pass + // is only enabled when testing importing via the 'opt' tool, which does + // not do the ThinLink that would normally determine what values to promote. + for (auto &I : *Index) { + for (auto &S : I.second) { + if (GlobalValue::isLocalLinkage(S->linkage())) + S->setLinkage(GlobalValue::ExternalLinkage); + } + } -/// Parse the summary index out of an IR file and return the summary -/// index object if found, or nullptr if not. 
-static std::unique_ptr<ModuleSummaryIndex> getModuleSummaryIndexForFile( - StringRef Path, std::string &Error, - const DiagnosticHandlerFunction &DiagnosticHandler) { - std::unique_ptr<MemoryBuffer> Buffer; - ErrorOr<std::unique_ptr<MemoryBuffer>> BufferOrErr = - MemoryBuffer::getFile(Path); - if (std::error_code EC = BufferOrErr.getError()) { - Error = EC.message(); - return nullptr; + // Next we need to promote to global scope and rename any local values that + // are potentially exported to other modules. + if (renameModuleForThinLTO(M, *Index, nullptr)) { + errs() << "Error renaming module\n"; + return false; } - Buffer = std::move(BufferOrErr.get()); - ErrorOr<std::unique_ptr<object::ModuleSummaryIndexObjectFile>> ObjOrErr = - object::ModuleSummaryIndexObjectFile::create(Buffer->getMemBufferRef(), - DiagnosticHandler); - if (std::error_code EC = ObjOrErr.getError()) { - Error = EC.message(); - return nullptr; + + // Perform the import now. + auto ModuleLoader = [&M](StringRef Identifier) { + return loadFile(Identifier, M.getContext()); + }; + FunctionImporter Importer(*Index, ModuleLoader); + Expected<bool> Result = Importer.importFunctions( + M, ImportList, !DontForceImportReferencedDiscardableSymbols); + + // FIXME: Probably need to propagate Errors through the pass manager. + if (!Result) { + logAllUnhandledErrors(Result.takeError(), errs(), + "Error importing module: "); + return false; } - return (*ObjOrErr)->takeIndex(); + + return *Result; } namespace { /// Pass that performs cross-module function import provided a summary file. -class FunctionImportPass : public ModulePass { - /// Optional module summary index to use for importing, otherwise - /// the summary-file option must be specified. - const ModuleSummaryIndex *Index; - +class FunctionImportLegacyPass : public ModulePass { public: /// Pass identification, replacement for typeid static char ID; /// Specify pass name for debug output - const char *getPassName() const override { return "Function Importing"; } + StringRef getPassName() const override { return "Function Importing"; } - explicit FunctionImportPass(const ModuleSummaryIndex *Index = nullptr) - : ModulePass(ID), Index(Index) {} + explicit FunctionImportLegacyPass() : ModulePass(ID) {} bool runOnModule(Module &M) override { if (skipModule(M)) return false; - if (SummaryFile.empty() && !Index) - report_fatal_error("error: -function-import requires -summary-file or " - "file from frontend\n"); - std::unique_ptr<ModuleSummaryIndex> IndexPtr; - if (!SummaryFile.empty()) { - if (Index) - report_fatal_error("error: -summary-file and index from frontend\n"); - std::string Error; - IndexPtr = - getModuleSummaryIndexForFile(SummaryFile, Error, diagnosticHandler); - if (!IndexPtr) { - errs() << "Error loading file '" << SummaryFile << "': " << Error - << "\n"; - return false; - } - Index = IndexPtr.get(); - } - - // First step is collecting the import list. - FunctionImporter::ImportMapTy ImportList; - ComputeCrossModuleImportForModule(M.getModuleIdentifier(), *Index, - ImportList); - - // Next we need to promote to global scope and rename any local values that - // are potentially exported to other modules. - if (renameModuleForThinLTO(M, *Index, nullptr)) { - errs() << "Error renaming module\n"; - return false; - } - - // Perform the import now. 
- auto ModuleLoader = [&M](StringRef Identifier) { - return loadFile(Identifier, M.getContext()); - }; - FunctionImporter Importer(*Index, ModuleLoader); - return Importer.importFunctions( - M, ImportList, !DontForceImportReferencedDiscardableSymbols); + return doImportingForModule(M); } }; } // anonymous namespace -char FunctionImportPass::ID = 0; -INITIALIZE_PASS_BEGIN(FunctionImportPass, "function-import", - "Summary Based Function Import", false, false) -INITIALIZE_PASS_END(FunctionImportPass, "function-import", - "Summary Based Function Import", false, false) +PreservedAnalyses FunctionImportPass::run(Module &M, + ModuleAnalysisManager &AM) { + if (!doImportingForModule(M)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +char FunctionImportLegacyPass::ID = 0; +INITIALIZE_PASS(FunctionImportLegacyPass, "function-import", + "Summary Based Function Import", false, false) namespace llvm { -Pass *createFunctionImportPass(const ModuleSummaryIndex *Index = nullptr) { - return new FunctionImportPass(Index); +Pass *createFunctionImportPass() { + return new FunctionImportLegacyPass(); } } diff --git a/contrib/llvm/lib/Transforms/IPO/GlobalDCE.cpp b/contrib/llvm/lib/Transforms/IPO/GlobalDCE.cpp index 4c74698..7a04de3 100644 --- a/contrib/llvm/lib/Transforms/IPO/GlobalDCE.cpp +++ b/contrib/llvm/lib/Transforms/IPO/GlobalDCE.cpp @@ -162,45 +162,29 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &) { GIF.setResolver(nullptr); } - if (!DeadFunctions.empty()) { - // Now that all interferences have been dropped, delete the actual objects - // themselves. - for (Function *F : DeadFunctions) { - RemoveUnusedGlobalValue(*F); - M.getFunctionList().erase(F); - } - NumFunctions += DeadFunctions.size(); + // Now that all interferences have been dropped, delete the actual objects + // themselves. + auto EraseUnusedGlobalValue = [&](GlobalValue *GV) { + RemoveUnusedGlobalValue(*GV); + GV->eraseFromParent(); Changed = true; - } + }; - if (!DeadGlobalVars.empty()) { - for (GlobalVariable *GV : DeadGlobalVars) { - RemoveUnusedGlobalValue(*GV); - M.getGlobalList().erase(GV); - } - NumVariables += DeadGlobalVars.size(); - Changed = true; - } + NumFunctions += DeadFunctions.size(); + for (Function *F : DeadFunctions) + EraseUnusedGlobalValue(F); - // Now delete any dead aliases. - if (!DeadAliases.empty()) { - for (GlobalAlias *GA : DeadAliases) { - RemoveUnusedGlobalValue(*GA); - M.getAliasList().erase(GA); - } - NumAliases += DeadAliases.size(); - Changed = true; - } + NumVariables += DeadGlobalVars.size(); + for (GlobalVariable *GV : DeadGlobalVars) + EraseUnusedGlobalValue(GV); - // Now delete any dead aliases. - if (!DeadIFuncs.empty()) { - for (GlobalIFunc *GIF : DeadIFuncs) { - RemoveUnusedGlobalValue(*GIF); - M.getIFuncList().erase(GIF); - } - NumIFuncs += DeadIFuncs.size(); - Changed = true; - } + NumAliases += DeadAliases.size(); + for (GlobalAlias *GA : DeadAliases) + EraseUnusedGlobalValue(GA); + + NumIFuncs += DeadIFuncs.size(); + for (GlobalIFunc *GIF : DeadIFuncs) + EraseUnusedGlobalValue(GIF); // Make sure that all memory is released AliveGlobals.clear(); diff --git a/contrib/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/contrib/llvm/lib/Transforms/IPO/GlobalOpt.cpp index 99b12d4..5b0d5e3 100644 --- a/contrib/llvm/lib/Transforms/IPO/GlobalOpt.cpp +++ b/contrib/llvm/lib/Transforms/IPO/GlobalOpt.cpp @@ -371,14 +371,14 @@ static bool IsUserOfGlobalSafeForSRA(User *U, GlobalValue *GV) { ++GEPI; // Skip over the pointer index. 
// If this is a use of an array allocation, do a bit more checking for sanity. - if (ArrayType *AT = dyn_cast<ArrayType>(*GEPI)) { - uint64_t NumElements = AT->getNumElements(); + if (GEPI.isSequential()) { ConstantInt *Idx = cast<ConstantInt>(U->getOperand(2)); // Check to make sure that index falls within the array. If not, // something funny is going on, so we won't do the optimization. // - if (Idx->getZExtValue() >= NumElements) + if (GEPI.isBoundedSequential() && + Idx->getZExtValue() >= GEPI.getSequentialNumElements()) return false; // We cannot scalar repl this level of the array unless any array @@ -391,19 +391,13 @@ static bool IsUserOfGlobalSafeForSRA(User *U, GlobalValue *GV) { for (++GEPI; // Skip array index. GEPI != E; ++GEPI) { - uint64_t NumElements; - if (ArrayType *SubArrayTy = dyn_cast<ArrayType>(*GEPI)) - NumElements = SubArrayTy->getNumElements(); - else if (VectorType *SubVectorTy = dyn_cast<VectorType>(*GEPI)) - NumElements = SubVectorTy->getNumElements(); - else { - assert((*GEPI)->isStructTy() && - "Indexed GEP type is not array, vector, or struct!"); + if (GEPI.isStruct()) continue; - } ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPI.getOperand()); - if (!IdxVal || IdxVal->getZExtValue() >= NumElements) + if (!IdxVal || + (GEPI.isBoundedSequential() && + IdxVal->getZExtValue() >= GEPI.getSequentialNumElements())) return false; } } @@ -473,12 +467,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { NGV->setAlignment(NewAlign); } } else if (SequentialType *STy = dyn_cast<SequentialType>(Ty)) { - unsigned NumElements = 0; - if (ArrayType *ATy = dyn_cast<ArrayType>(STy)) - NumElements = ATy->getNumElements(); - else - NumElements = cast<VectorType>(STy)->getNumElements(); - + unsigned NumElements = STy->getNumElements(); if (NumElements > 16 && GV->hasNUsesOrMore(16)) return nullptr; // It's not worth it. NewGlobals.reserve(NumElements); @@ -1653,7 +1642,7 @@ static bool deleteIfDead(GlobalValue &GV, SmallSet<const Comdat *, 8> &NotDiscardableComdats) { GV.removeDeadConstantUsers(); - if (!GV.isDiscardableIfUnused()) + if (!GV.isDiscardableIfUnused() && !GV.isDeclaration()) return false; if (const Comdat *C = GV.getComdat()) @@ -1662,7 +1651,7 @@ static bool deleteIfDead(GlobalValue &GV, bool Dead; if (auto *F = dyn_cast<Function>(&GV)) - Dead = F->isDefTriviallyDead(); + Dead = (F->isDeclaration() && F->use_empty()) || F->isDefTriviallyDead(); else Dead = GV.use_empty(); if (!Dead) @@ -1737,7 +1726,7 @@ static bool isPointerValueDeadOnEntryToFunction( for (auto *L : Loads) { auto *LTy = L->getType(); - if (!std::any_of(Stores.begin(), Stores.end(), [&](StoreInst *S) { + if (none_of(Stores, [&](const StoreInst *S) { auto *STy = S->getValueOperand()->getType(); // The load is only dominated by the store if DomTree says so // and the number of bits loaded in L is less than or equal to @@ -2079,10 +2068,10 @@ OptimizeGlobalVars(Module &M, TargetLibraryInfo *TLI, GV->setLinkage(GlobalValue::InternalLinkage); // Simplify the initializer. 
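The two SRA hunks above move IsUserOfGlobalSafeForSRA off of dereferencing the gep_type_iterator and onto its type-query helpers. A small sketch of the same idiom in isolation, assuming the usual GetElementPtrTypeIterator.h interface; only member functions that appear in this patch are used, and the free function name is made up:

#include "llvm/IR/Constants.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Returns false if any constant index steps outside a bounded sequential
// (array/vector) type, mirroring the checks rewritten above.
static bool constantIndicesInBounds(GetElementPtrInst *GEP) {
  for (auto GTI = gep_type_begin(GEP), E = gep_type_end(GEP); GTI != E;
       ++GTI) {
    if (GTI.isStruct())
      continue; // struct field indices are checked by the verifier
    auto *CI = dyn_cast<ConstantInt>(GTI.getOperand());
    if (!CI)
      continue; // variable index: nothing to check at this level
    if (GTI.isBoundedSequential() &&
        CI->getZExtValue() >= GTI.getSequentialNumElements())
      return false;
  }
  return true;
}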
if (GV->hasInitializer()) - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(GV->getInitializer())) { + if (auto *C = dyn_cast<Constant>(GV->getInitializer())) { auto &DL = M.getDataLayout(); - Constant *New = ConstantFoldConstantExpression(CE, DL, TLI); - if (New && New != CE) + Constant *New = ConstantFoldConstant(C, DL, TLI); + if (New && New != C) GV->setInitializer(New); } @@ -2125,12 +2114,7 @@ static Constant *EvaluateStoreInto(Constant *Init, Constant *Val, ConstantInt *CI = cast<ConstantInt>(Addr->getOperand(OpNo)); SequentialType *InitTy = cast<SequentialType>(Init->getType()); - - uint64_t NumElts; - if (ArrayType *ATy = dyn_cast<ArrayType>(InitTy)) - NumElts = ATy->getNumElements(); - else - NumElts = InitTy->getVectorNumElements(); + uint64_t NumElts = InitTy->getNumElements(); // Break up the array into elements. for (uint64_t i = 0, e = NumElts; i != e; ++i) @@ -2565,7 +2549,7 @@ static bool optimizeGlobalsInModule( return Changed; } -PreservedAnalyses GlobalOptPass::run(Module &M, AnalysisManager<Module> &AM) { +PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) { auto &DL = M.getDataLayout(); auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); auto &FAM = diff --git a/contrib/llvm/lib/Transforms/IPO/GlobalSplit.cpp b/contrib/llvm/lib/Transforms/IPO/GlobalSplit.cpp new file mode 100644 index 0000000..bbbd096 --- /dev/null +++ b/contrib/llvm/lib/Transforms/IPO/GlobalSplit.cpp @@ -0,0 +1,171 @@ +//===- GlobalSplit.cpp - global variable splitter -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass uses inrange annotations on GEP indices to split globals where +// beneficial. Clang currently attaches these annotations to references to +// virtual table globals under the Itanium ABI for the benefit of the +// whole-program virtual call optimization and control flow integrity passes. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/GlobalSplit.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/Pass.h" + +#include <set> + +using namespace llvm; + +namespace { + +bool splitGlobal(GlobalVariable &GV) { + // If the address of the global is taken outside of the module, we cannot + // apply this transformation. + if (!GV.hasLocalLinkage()) + return false; + + // We currently only know how to split ConstantStructs. + auto *Init = dyn_cast_or_null<ConstantStruct>(GV.getInitializer()); + if (!Init) + return false; + + // Verify that each user of the global is an inrange getelementptr constant. + // From this it follows that any loads from or stores to that global must use + // a pointer derived from an inrange getelementptr constant, which is + // sufficient to allow us to apply the splitting transform. 
+ for (User *U : GV.users()) { + if (!isa<Constant>(U)) + return false; + + auto *GEP = dyn_cast<GEPOperator>(U); + if (!GEP || !GEP->getInRangeIndex() || *GEP->getInRangeIndex() != 1 || + !isa<ConstantInt>(GEP->getOperand(1)) || + !cast<ConstantInt>(GEP->getOperand(1))->isZero() || + !isa<ConstantInt>(GEP->getOperand(2))) + return false; + } + + SmallVector<MDNode *, 2> Types; + GV.getMetadata(LLVMContext::MD_type, Types); + + const DataLayout &DL = GV.getParent()->getDataLayout(); + const StructLayout *SL = DL.getStructLayout(Init->getType()); + + IntegerType *Int32Ty = Type::getInt32Ty(GV.getContext()); + + std::vector<GlobalVariable *> SplitGlobals(Init->getNumOperands()); + for (unsigned I = 0; I != Init->getNumOperands(); ++I) { + // Build a global representing this split piece. + auto *SplitGV = + new GlobalVariable(*GV.getParent(), Init->getOperand(I)->getType(), + GV.isConstant(), GlobalValue::PrivateLinkage, + Init->getOperand(I), GV.getName() + "." + utostr(I)); + SplitGlobals[I] = SplitGV; + + unsigned SplitBegin = SL->getElementOffset(I); + unsigned SplitEnd = (I == Init->getNumOperands() - 1) + ? SL->getSizeInBytes() + : SL->getElementOffset(I + 1); + + // Rebuild type metadata, adjusting by the split offset. + // FIXME: See if we can use DW_OP_piece to preserve debug metadata here. + for (MDNode *Type : Types) { + uint64_t ByteOffset = cast<ConstantInt>( + cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) + ->getZExtValue(); + if (ByteOffset < SplitBegin || ByteOffset >= SplitEnd) + continue; + SplitGV->addMetadata( + LLVMContext::MD_type, + *MDNode::get(GV.getContext(), + {ConstantAsMetadata::get( + ConstantInt::get(Int32Ty, ByteOffset - SplitBegin)), + Type->getOperand(1)})); + } + } + + for (User *U : GV.users()) { + auto *GEP = cast<GEPOperator>(U); + unsigned I = cast<ConstantInt>(GEP->getOperand(2))->getZExtValue(); + if (I >= SplitGlobals.size()) + continue; + + SmallVector<Value *, 4> Ops; + Ops.push_back(ConstantInt::get(Int32Ty, 0)); + for (unsigned I = 3; I != GEP->getNumOperands(); ++I) + Ops.push_back(GEP->getOperand(I)); + + auto *NewGEP = ConstantExpr::getGetElementPtr( + SplitGlobals[I]->getInitializer()->getType(), SplitGlobals[I], Ops, + GEP->isInBounds()); + GEP->replaceAllUsesWith(NewGEP); + } + + // Finally, remove the original global. Any remaining uses refer to invalid + // elements of the global, so replace with undef. + if (!GV.use_empty()) + GV.replaceAllUsesWith(UndefValue::get(GV.getType())); + GV.eraseFromParent(); + return true; +} + +bool splitGlobals(Module &M) { + // First, see if the module uses either of the llvm.type.test or + // llvm.type.checked.load intrinsics, which indicates that splitting globals + // may be beneficial. 
+ Function *TypeTestFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_test)); + Function *TypeCheckedLoadFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); + if ((!TypeTestFunc || TypeTestFunc->use_empty()) && + (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty())) + return false; + + bool Changed = false; + for (auto I = M.global_begin(); I != M.global_end();) { + GlobalVariable &GV = *I; + ++I; + Changed |= splitGlobal(GV); + } + return Changed; +} + +struct GlobalSplit : public ModulePass { + static char ID; + GlobalSplit() : ModulePass(ID) { + initializeGlobalSplitPass(*PassRegistry::getPassRegistry()); + } + bool runOnModule(Module &M) { + if (skipModule(M)) + return false; + + return splitGlobals(M); + } +}; + +} + +INITIALIZE_PASS(GlobalSplit, "globalsplit", "Global splitter", false, false) +char GlobalSplit::ID = 0; + +ModulePass *llvm::createGlobalSplitPass() { + return new GlobalSplit; +} + +PreservedAnalyses GlobalSplitPass::run(Module &M, ModuleAnalysisManager &AM) { + if (!splitGlobals(M)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} diff --git a/contrib/llvm/lib/Transforms/IPO/IPO.cpp b/contrib/llvm/lib/Transforms/IPO/IPO.cpp index 3507eba..89518f3 100644 --- a/contrib/llvm/lib/Transforms/IPO/IPO.cpp +++ b/contrib/llvm/lib/Transforms/IPO/IPO.cpp @@ -18,6 +18,7 @@ #include "llvm/InitializePasses.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" using namespace llvm; @@ -31,8 +32,9 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeForceFunctionAttrsLegacyPassPass(Registry); initializeGlobalDCELegacyPassPass(Registry); initializeGlobalOptLegacyPassPass(Registry); + initializeGlobalSplitPass(Registry); initializeIPCPPass(Registry); - initializeAlwaysInlinerPass(Registry); + initializeAlwaysInlinerLegacyPassPass(Registry); initializeSimpleInlinerPass(Registry); initializeInferFunctionAttrsLegacyPassPass(Registry); initializeInternalizeLegacyPassPass(Registry); @@ -53,7 +55,7 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeBarrierNoopPass(Registry); initializeEliminateAvailableExternallyLegacyPassPass(Registry); initializeSampleProfileLoaderLegacyPassPass(Registry); - initializeFunctionImportPassPass(Registry); + initializeFunctionImportLegacyPassPass(Registry); initializeWholeProgramDevirtPass(Registry); } @@ -82,7 +84,7 @@ void LLVMAddFunctionInliningPass(LLVMPassManagerRef PM) { } void LLVMAddAlwaysInlinerPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(llvm::createAlwaysInlinerPass()); + unwrap(PM)->add(llvm::createAlwaysInlinerLegacyPass()); } void LLVMAddGlobalDCEPass(LLVMPassManagerRef PM) { diff --git a/contrib/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp b/contrib/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp index ab2d2bd..2ef299d 100644 --- a/contrib/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp +++ b/contrib/llvm/lib/Transforms/IPO/InferFunctionAttrs.cpp @@ -34,7 +34,7 @@ static bool inferAllPrototypeAttributes(Module &M, } PreservedAnalyses InferFunctionAttrsPass::run(Module &M, - AnalysisManager<Module> &AM) { + ModuleAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); if (!inferAllPrototypeAttributes(M, TLI)) diff --git a/contrib/llvm/lib/Transforms/IPO/InlineAlways.cpp b/contrib/llvm/lib/Transforms/IPO/InlineAlways.cpp deleted file mode 100644 index cb1ab95..0000000 --- a/contrib/llvm/lib/Transforms/IPO/InlineAlways.cpp +++ 
/dev/null @@ -1,103 +0,0 @@ -//===- InlineAlways.cpp - Code to inline always_inline functions ----------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This file implements a custom inliner that handles only functions that -// are marked as "always inline". -// -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/IPO.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/CallGraph.h" -#include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/ProfileSummaryInfo.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/CallSite.h" -#include "llvm/IR/CallingConv.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/Transforms/IPO/InlinerPass.h" - -using namespace llvm; - -#define DEBUG_TYPE "inline" - -namespace { - -/// \brief Inliner pass which only handles "always inline" functions. -class AlwaysInliner : public Inliner { - -public: - AlwaysInliner() : Inliner(ID, /*InsertLifetime*/ true) { - initializeAlwaysInlinerPass(*PassRegistry::getPassRegistry()); - } - - AlwaysInliner(bool InsertLifetime) : Inliner(ID, InsertLifetime) { - initializeAlwaysInlinerPass(*PassRegistry::getPassRegistry()); - } - - /// Main run interface method. We override here to avoid calling skipSCC(). - bool runOnSCC(CallGraphSCC &SCC) override { return inlineCalls(SCC); } - - static char ID; // Pass identification, replacement for typeid - - InlineCost getInlineCost(CallSite CS) override; - - using llvm::Pass::doFinalization; - bool doFinalization(CallGraph &CG) override { - return removeDeadFunctions(CG, /*AlwaysInlineOnly=*/ true); - } -}; - -} - -char AlwaysInliner::ID = 0; -INITIALIZE_PASS_BEGIN(AlwaysInliner, "always-inline", - "Inliner for always_inline functions", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(AlwaysInliner, "always-inline", - "Inliner for always_inline functions", false, false) - -Pass *llvm::createAlwaysInlinerPass() { return new AlwaysInliner(); } - -Pass *llvm::createAlwaysInlinerPass(bool InsertLifetime) { - return new AlwaysInliner(InsertLifetime); -} - -/// \brief Get the inline cost for the always-inliner. -/// -/// The always inliner *only* handles functions which are marked with the -/// attribute to force inlining. As such, it is dramatically simpler and avoids -/// using the powerful (but expensive) inline cost analysis. Instead it uses -/// a very simple and boring direct walk of the instructions looking for -/// impossible-to-inline constructs. -/// -/// Note, it would be possible to go to some lengths to cache the information -/// computed here, but as we only expect to do this for relatively few and -/// small functions which have the explicit attribute to force inlining, it is -/// likely not worth it in practice. 
-InlineCost AlwaysInliner::getInlineCost(CallSite CS) { - Function *Callee = CS.getCalledFunction(); - - // Only inline direct calls to functions with always-inline attributes - // that are viable for inlining. FIXME: We shouldn't even get here for - // declarations. - if (Callee && !Callee->isDeclaration() && - CS.hasFnAttr(Attribute::AlwaysInline) && isInlineViable(*Callee)) - return InlineCost::getAlways(); - - return InlineCost::getNever(); -} diff --git a/contrib/llvm/lib/Transforms/IPO/InlineSimple.cpp b/contrib/llvm/lib/Transforms/IPO/InlineSimple.cpp index 2aa650b..1770445 100644 --- a/contrib/llvm/lib/Transforms/IPO/InlineSimple.cpp +++ b/contrib/llvm/lib/Transforms/IPO/InlineSimple.cpp @@ -25,7 +25,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/IPO/InlinerPass.h" +#include "llvm/Transforms/IPO/Inliner.h" using namespace llvm; @@ -38,21 +38,17 @@ namespace { /// The common implementation of the inlining logic is shared between this /// inliner pass and the always inliner pass. The two passes use different cost /// analyses to determine when to inline. -class SimpleInliner : public Inliner { - // This field is populated based on one of the following: - // * optimization or size-optimization levels, - // * the --inline-threshold flag, or - // * a user specified value. - int DefaultThreshold; +class SimpleInliner : public LegacyInlinerBase { + + InlineParams Params; public: - SimpleInliner() - : Inliner(ID), DefaultThreshold(llvm::getDefaultInlineThreshold()) { + SimpleInliner() : LegacyInlinerBase(ID), Params(llvm::getInlineParams()) { initializeSimpleInlinerPass(*PassRegistry::getPassRegistry()); } - explicit SimpleInliner(int Threshold) - : Inliner(ID), DefaultThreshold(Threshold) { + explicit SimpleInliner(InlineParams Params) + : LegacyInlinerBase(ID), Params(Params) { initializeSimpleInlinerPass(*PassRegistry::getPassRegistry()); } @@ -61,7 +57,11 @@ public: InlineCost getInlineCost(CallSite CS) override { Function *Callee = CS.getCalledFunction(); TargetTransformInfo &TTI = TTIWP->getTTI(*Callee); - return llvm::getInlineCost(CS, DefaultThreshold, TTI, ACT, PSI); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = + [&](Function &F) -> AssumptionCache & { + return ACT->getAssumptionCache(F); + }; + return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache, PSI); } bool runOnSCC(CallGraphSCC &SCC) override; @@ -69,39 +69,43 @@ public: private: TargetTransformInfoWrapperPass *TTIWP; + }; } // end anonymous namespace char SimpleInliner::ID = 0; -INITIALIZE_PASS_BEGIN(SimpleInliner, "inline", - "Function Integration/Inlining", false, false) +INITIALIZE_PASS_BEGIN(SimpleInliner, "inline", "Function Integration/Inlining", + false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(SimpleInliner, "inline", - "Function Integration/Inlining", false, false) +INITIALIZE_PASS_END(SimpleInliner, "inline", "Function Integration/Inlining", + false, false) Pass *llvm::createFunctionInliningPass() { return new SimpleInliner(); } Pass *llvm::createFunctionInliningPass(int Threshold) { - return new SimpleInliner(Threshold); + return new SimpleInliner(llvm::getInlineParams(Threshold)); } Pass *llvm::createFunctionInliningPass(unsigned OptLevel, unsigned 
SizeOptLevel) { - return new SimpleInliner( - llvm::computeThresholdFromOptLevels(OptLevel, SizeOptLevel)); + return new SimpleInliner(llvm::getInlineParams(OptLevel, SizeOptLevel)); +} + +Pass *llvm::createFunctionInliningPass(InlineParams &Params) { + return new SimpleInliner(Params); } bool SimpleInliner::runOnSCC(CallGraphSCC &SCC) { TTIWP = &getAnalysis<TargetTransformInfoWrapperPass>(); - return Inliner::runOnSCC(SCC); + return LegacyInlinerBase::runOnSCC(SCC); } void SimpleInliner::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<TargetTransformInfoWrapperPass>(); - Inliner::getAnalysisUsage(AU); + LegacyInlinerBase::getAnalysisUsage(AU); } diff --git a/contrib/llvm/lib/Transforms/IPO/Inliner.cpp b/contrib/llvm/lib/Transforms/IPO/Inliner.cpp index 79535ca..3f4731c 100644 --- a/contrib/llvm/lib/Transforms/IPO/Inliner.cpp +++ b/contrib/llvm/lib/Transforms/IPO/Inliner.cpp @@ -13,6 +13,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/IPO/Inliner.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -20,19 +21,21 @@ #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/IPO/InlinerPass.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; #define DEBUG_TYPE "inline" @@ -47,15 +50,44 @@ STATISTIC(NumMergedAllocas, "Number of allocas merged together"); // if those would be more profitable and blocked inline steps. STATISTIC(NumCallerCallersAnalyzed, "Number of caller-callers analyzed"); -Inliner::Inliner(char &ID) : CallGraphSCCPass(ID), InsertLifetime(true) {} - -Inliner::Inliner(char &ID, bool InsertLifetime) +/// Flag to disable manual alloca merging. +/// +/// Merging of allocas was originally done as a stack-size saving technique +/// prior to LLVM's code generator having support for stack coloring based on +/// lifetime markers. It is now in the process of being removed. To experiment +/// with disabling it and relying fully on lifetime marker based stack +/// coloring, you can pass this flag to LLVM. 
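For out-of-tree users of the legacy pass, the InlineSimple.cpp hunk above replaces the bare threshold with an InlineParams bundle. A hedged usage sketch built only from the entry points visible in this patch; the header locations are the usual ones and are assumed, and the helper name is hypothetical:

#include "llvm/Analysis/InlineCost.h"   // InlineParams, getInlineParams
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/IPO.h"        // createFunctionInliningPass
using namespace llvm;

static void addInliner(legacy::PassManager &PM) {
  // Derive the parameter bundle from opt/size levels instead of passing a
  // single raw threshold; SimpleInliner stores a copy of Params.
  InlineParams Params = getInlineParams(/*OptLevel=*/2, /*SizeOptLevel=*/0);
  PM.add(createFunctionInliningPass(Params));
}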
+static cl::opt<bool> + DisableInlinedAllocaMerging("disable-inlined-alloca-merging", + cl::init(false), cl::Hidden); + +namespace { +enum class InlinerFunctionImportStatsOpts { + No = 0, + Basic = 1, + Verbose = 2, +}; + +cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats( + "inliner-function-import-stats", + cl::init(InlinerFunctionImportStatsOpts::No), + cl::values(clEnumValN(InlinerFunctionImportStatsOpts::Basic, "basic", + "basic statistics"), + clEnumValN(InlinerFunctionImportStatsOpts::Verbose, "verbose", + "printing of statistics for each inlined function")), + cl::Hidden, cl::desc("Enable inliner stats for imported functions")); +} // namespace + +LegacyInlinerBase::LegacyInlinerBase(char &ID) + : CallGraphSCCPass(ID), InsertLifetime(true) {} + +LegacyInlinerBase::LegacyInlinerBase(char &ID, bool InsertLifetime) : CallGraphSCCPass(ID), InsertLifetime(InsertLifetime) {} /// For this class, we declare that we require and preserve the call graph. /// If the derived class implements this method, it should /// always explicitly call the implementation here. -void Inliner::getAnalysisUsage(AnalysisUsage &AU) const { +void LegacyInlinerBase::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<ProfileSummaryInfoWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); @@ -63,62 +95,33 @@ void Inliner::getAnalysisUsage(AnalysisUsage &AU) const { CallGraphSCCPass::getAnalysisUsage(AU); } +typedef DenseMap<ArrayType *, std::vector<AllocaInst *>> InlinedArrayAllocasTy; -typedef DenseMap<ArrayType*, std::vector<AllocaInst*> > -InlinedArrayAllocasTy; - -/// If it is possible to inline the specified call site, -/// do so and update the CallGraph for this operation. +/// Look at all of the allocas that we inlined through this call site. If we +/// have already inlined other allocas through other calls into this function, +/// then we know that they have disjoint lifetimes and that we can merge them. /// -/// This function also does some basic book-keeping to update the IR. The -/// InlinedArrayAllocas map keeps track of any allocas that are already -/// available from other functions inlined into the caller. If we are able to -/// inline this call site we attempt to reuse already available allocas or add -/// any new allocas to the set if not possible. -static bool InlineCallIfPossible(Pass &P, CallSite CS, InlineFunctionInfo &IFI, - InlinedArrayAllocasTy &InlinedArrayAllocas, - int InlineHistory, bool InsertLifetime) { - Function *Callee = CS.getCalledFunction(); - Function *Caller = CS.getCaller(); - - // We need to manually construct BasicAA directly in order to disable - // its use of other function analyses. - BasicAAResult BAR(createLegacyPMBasicAAResult(P, *Callee)); - - // Construct our own AA results for this function. We do this manually to - // work around the limitations of the legacy pass manager. - AAResults AAR(createLegacyPMAAResults(P, *Callee, BAR)); - - // Try to inline the function. Get the list of static allocas that were - // inlined. - if (!InlineFunction(CS, IFI, &AAR, InsertLifetime)) - return false; - - AttributeFuncs::mergeAttributesForInlining(*Caller, *Callee); +/// There are many heuristics possible for merging these allocas, and the +/// different options have different tradeoffs. One thing that we *really* +/// don't want to hurt is SRoA: once inlining happens, often allocas are no +/// longer address taken and so they can be promoted. 
+/// +/// Our "solution" for that is to only merge allocas whose outermost type is an +/// array type. These are usually not promoted because someone is using a +/// variable index into them. These are also often the most important ones to +/// merge. +/// +/// A better solution would be to have real memory lifetime markers in the IR +/// and not have the inliner do any merging of allocas at all. This would +/// allow the backend to do proper stack slot coloring of all allocas that +/// *actually make it to the backend*, which is really what we want. +/// +/// Because we don't have this information, we do this simple and useful hack. +static void mergeInlinedArrayAllocas( + Function *Caller, InlineFunctionInfo &IFI, + InlinedArrayAllocasTy &InlinedArrayAllocas, int InlineHistory) { + SmallPtrSet<AllocaInst *, 16> UsedAllocas; - // Look at all of the allocas that we inlined through this call site. If we - // have already inlined other allocas through other calls into this function, - // then we know that they have disjoint lifetimes and that we can merge them. - // - // There are many heuristics possible for merging these allocas, and the - // different options have different tradeoffs. One thing that we *really* - // don't want to hurt is SRoA: once inlining happens, often allocas are no - // longer address taken and so they can be promoted. - // - // Our "solution" for that is to only merge allocas whose outermost type is an - // array type. These are usually not promoted because someone is using a - // variable index into them. These are also often the most important ones to - // merge. - // - // A better solution would be to have real memory lifetime markers in the IR - // and not have the inliner do any merging of allocas at all. This would - // allow the backend to do proper stack slot coloring of all allocas that - // *actually make it to the backend*, which is really what we want. - // - // Because we don't have this information, we do this simple and useful hack. - // - SmallPtrSet<AllocaInst*, 16> UsedAllocas; - // When processing our SCC, check to see if CS was inlined from some other // call site. For example, if we're processing "A" in this code: // A() { B() } @@ -131,25 +134,25 @@ static bool InlineCallIfPossible(Pass &P, CallSite CS, InlineFunctionInfo &IFI, // because their scopes are not disjoint. We could make this smarter by // keeping track of the inline history for each alloca in the // InlinedArrayAllocas but this isn't likely to be a significant win. - if (InlineHistory != -1) // Only do merging for top-level call sites in SCC. - return true; - + if (InlineHistory != -1) // Only do merging for top-level call sites in SCC. + return; + // Loop over all the allocas we have so far and see if they can be merged with // a previously inlined alloca. If not, remember that we had it. - for (unsigned AllocaNo = 0, e = IFI.StaticAllocas.size(); - AllocaNo != e; ++AllocaNo) { + for (unsigned AllocaNo = 0, e = IFI.StaticAllocas.size(); AllocaNo != e; + ++AllocaNo) { AllocaInst *AI = IFI.StaticAllocas[AllocaNo]; - + // Don't bother trying to merge array allocations (they will usually be // canonicalized to be an allocation *of* an array), or allocations whose // type is not itself an array (because we're afraid of pessimizing SRoA). ArrayType *ATy = dyn_cast<ArrayType>(AI->getAllocatedType()); if (!ATy || AI->isArrayAllocation()) continue; - + // Get the list of all available allocas for this array type. 
- std::vector<AllocaInst*> &AllocasForType = InlinedArrayAllocas[ATy]; - + std::vector<AllocaInst *> &AllocasForType = InlinedArrayAllocas[ATy]; + // Loop over the allocas in AllocasForType to see if we can reuse one. Note // that we have to be careful not to reuse the same "available" alloca for // multiple different allocas that we just inlined, we use the 'UsedAllocas' @@ -160,24 +163,24 @@ static bool InlineCallIfPossible(Pass &P, CallSite CS, InlineFunctionInfo &IFI, unsigned Align1 = AI->getAlignment(), Align2 = AvailableAlloca->getAlignment(); - + // The available alloca has to be in the right function, not in some other // function in this SCC. if (AvailableAlloca->getParent() != AI->getParent()) continue; - + // If the inlined function already uses this alloca then we can't reuse // it. if (!UsedAllocas.insert(AvailableAlloca).second) continue; - + // Otherwise, we *can* reuse it, RAUW AI into AvailableAlloca and declare // success! - DEBUG(dbgs() << " ***MERGED ALLOCA: " << *AI << "\n\t\tINTO: " - << *AvailableAlloca << '\n'); - + DEBUG(dbgs() << " ***MERGED ALLOCA: " << *AI + << "\n\t\tINTO: " << *AvailableAlloca << '\n'); + // Move affected dbg.declare calls immediately after the new alloca to - // avoid the situation when a dbg.declare preceeds its alloca. + // avoid the situation when a dbg.declare precedes its alloca. if (auto *L = LocalAsMetadata::getIfExists(AI)) if (auto *MDV = MetadataAsValue::getIfExists(AI->getContext(), L)) for (User *U : MDV->users()) @@ -209,7 +212,7 @@ static bool InlineCallIfPossible(Pass &P, CallSite CS, InlineFunctionInfo &IFI, // If we already nuked the alloca, we're done with it. if (MergedAwayAlloca) continue; - + // If we were unable to merge away the alloca either because there are no // allocas of the right type available or because we reused them all // already, remember that this alloca came from an inlined function and mark @@ -218,19 +221,51 @@ static bool InlineCallIfPossible(Pass &P, CallSite CS, InlineFunctionInfo &IFI, AllocasForType.push_back(AI); UsedAllocas.insert(AI); } - - return true; } -static void emitAnalysis(CallSite CS, const Twine &Msg) { +/// If it is possible to inline the specified call site, +/// do so and update the CallGraph for this operation. +/// +/// This function also does some basic book-keeping to update the IR. The +/// InlinedArrayAllocas map keeps track of any allocas that are already +/// available from other functions inlined into the caller. If we are able to +/// inline this call site we attempt to reuse already available allocas or add +/// any new allocas to the set if not possible. +static bool InlineCallIfPossible( + CallSite CS, InlineFunctionInfo &IFI, + InlinedArrayAllocasTy &InlinedArrayAllocas, int InlineHistory, + bool InsertLifetime, function_ref<AAResults &(Function &)> &AARGetter, + ImportedFunctionsInliningStatistics &ImportedFunctionsStats) { + Function *Callee = CS.getCalledFunction(); Function *Caller = CS.getCaller(); - LLVMContext &Ctx = Caller->getContext(); - DebugLoc DLoc = CS.getInstruction()->getDebugLoc(); - emitOptimizationRemarkAnalysis(Ctx, DEBUG_TYPE, *Caller, DLoc, Msg); + + AAResults &AAR = AARGetter(*Callee); + + // Try to inline the function. Get the list of static allocas that were + // inlined. 
+ if (!InlineFunction(CS, IFI, &AAR, InsertLifetime)) + return false; + + if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) + ImportedFunctionsStats.recordInline(*Caller, *Callee); + + AttributeFuncs::mergeAttributesForInlining(*Caller, *Callee); + + if (!DisableInlinedAllocaMerging) + mergeInlinedArrayAllocas(Caller, IFI, InlinedArrayAllocas, InlineHistory); + + return true; } -bool Inliner::shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, - int &TotalSecondaryCost) { +/// Return true if inlining of CS can block the caller from being +/// inlined which is proved to be more beneficial. \p IC is the +/// estimated inline cost associated with callsite \p CS. +/// \p TotalAltCost will be set to the estimated cost of inlining the caller +/// if \p CS is suppressed for inlining. +static bool +shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, + int &TotalSecondaryCost, + function_ref<InlineCost(CallSite CS)> GetInlineCost) { // For now we only handle local or inline functions. if (!Caller->hasLocalLinkage() && !Caller->hasLinkOnceODRLinkage()) @@ -269,7 +304,7 @@ bool Inliner::shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, continue; } - InlineCost IC2 = getInlineCost(CS2); + InlineCost IC2 = GetInlineCost(CS2); ++NumCallerCallersAnalyzed; if (!IC2) { callerWillBeRemoved = false; @@ -278,7 +313,7 @@ bool Inliner::shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, if (IC2.isAlways()) continue; - // See if inlining or original callsite would erase the cost delta of + // See if inlining of the original callsite would erase the cost delta of // this callsite. We subtract off the penalty for the call instruction, // which we would be deleting. if (IC2.getCostDelta() <= CandidateCost) { @@ -291,7 +326,7 @@ bool Inliner::shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, // be removed entirely. We did not account for this above unless there // is only one caller of Caller. if (callerWillBeRemoved && !Caller->use_empty()) - TotalSecondaryCost += InlineConstants::LastCallToStaticBonus; + TotalSecondaryCost -= InlineConstants::LastCallToStaticBonus; if (inliningPreventsSomeOuterInline && TotalSecondaryCost < IC.getCost()) return true; @@ -300,63 +335,73 @@ bool Inliner::shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, } /// Return true if the inliner should attempt to inline at the given CallSite. 
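// A minimal, self-contained sketch of the trade-off computed by
// shouldBeDeferred above. Numbers and names are invented, and the
// last-call-to-static bonus plus the "actually blocks an outer inline" check
// are omitted for brevity: deferral wins when inlining the caller outward is
// estimated to be cheaper than inlining this call site inward.
static bool wouldDeferSketch(int CostHere, const int *OuterCosts, int NumOuter) {
  int TotalSecondaryCost = 0;
  for (int I = 0; I != NumOuter; ++I)
    TotalSecondaryCost += OuterCosts[I]; // cost of inlining Caller into each of its callers
  // E.g. CostHere = 100 and OuterCosts = {30, 40}: 70 < 100, so defer here
  // and let the caller be inlined into its own callers first.
  return TotalSecondaryCost < CostHere;
}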
-bool Inliner::shouldInline(CallSite CS) { - InlineCost IC = getInlineCost(CS); - +static bool shouldInline(CallSite CS, + function_ref<InlineCost(CallSite CS)> GetInlineCost, + OptimizationRemarkEmitter &ORE) { + using namespace ore; + InlineCost IC = GetInlineCost(CS); + Instruction *Call = CS.getInstruction(); + Function *Callee = CS.getCalledFunction(); + if (IC.isAlways()) { DEBUG(dbgs() << " Inlining: cost=always" - << ", Call: " << *CS.getInstruction() << "\n"); - emitAnalysis(CS, Twine(CS.getCalledFunction()->getName()) + - " should always be inlined (cost=always)"); + << ", Call: " << *CS.getInstruction() << "\n"); + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "AlwaysInline", Call) + << NV("Callee", Callee) + << " should always be inlined (cost=always)"); return true; } - + if (IC.isNever()) { DEBUG(dbgs() << " NOT Inlining: cost=never" - << ", Call: " << *CS.getInstruction() << "\n"); - emitAnalysis(CS, Twine(CS.getCalledFunction()->getName() + - " should never be inlined (cost=never)")); + << ", Call: " << *CS.getInstruction() << "\n"); + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "NeverInline", Call) + << NV("Callee", Callee) + << " should never be inlined (cost=never)"); return false; } - + Function *Caller = CS.getCaller(); if (!IC) { DEBUG(dbgs() << " NOT Inlining: cost=" << IC.getCost() - << ", thres=" << (IC.getCostDelta() + IC.getCost()) - << ", Call: " << *CS.getInstruction() << "\n"); - emitAnalysis(CS, Twine(CS.getCalledFunction()->getName() + - " too costly to inline (cost=") + - Twine(IC.getCost()) + ", threshold=" + - Twine(IC.getCostDelta() + IC.getCost()) + ")"); + << ", thres=" << (IC.getCostDelta() + IC.getCost()) + << ", Call: " << *CS.getInstruction() << "\n"); + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", Call) + << NV("Callee", Callee) << " too costly to inline (cost=" + << NV("Cost", IC.getCost()) << ", threshold=" + << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"); return false; } int TotalSecondaryCost = 0; - if (shouldBeDeferred(Caller, CS, IC, TotalSecondaryCost)) { + if (shouldBeDeferred(Caller, CS, IC, TotalSecondaryCost, GetInlineCost)) { DEBUG(dbgs() << " NOT Inlining: " << *CS.getInstruction() - << " Cost = " << IC.getCost() - << ", outer Cost = " << TotalSecondaryCost << '\n'); - emitAnalysis(CS, Twine("Not inlining. Cost of inlining " + - CS.getCalledFunction()->getName() + - " increases the cost of inlining " + - CS.getCaller()->getName() + " in other contexts")); + << " Cost = " << IC.getCost() + << ", outer Cost = " << TotalSecondaryCost << '\n'); + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, + "IncreaseCostInOtherContexts", Call) + << "Not inlining. 
Cost of inlining " << NV("Callee", Callee) + << " increases the cost of inlining " << NV("Caller", Caller) + << " in other contexts"); return false; } DEBUG(dbgs() << " Inlining: cost=" << IC.getCost() - << ", thres=" << (IC.getCostDelta() + IC.getCost()) - << ", Call: " << *CS.getInstruction() << '\n'); - emitAnalysis( - CS, CS.getCalledFunction()->getName() + Twine(" can be inlined into ") + - CS.getCaller()->getName() + " with cost=" + Twine(IC.getCost()) + - " (threshold=" + Twine(IC.getCostDelta() + IC.getCost()) + ")"); + << ", thres=" << (IC.getCostDelta() + IC.getCost()) + << ", Call: " << *CS.getInstruction() << '\n'); + ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "CanBeInlined", Call) + << NV("Callee", Callee) << " can be inlined into " + << NV("Caller", Caller) << " with cost=" << NV("Cost", IC.getCost()) + << " (threshold=" + << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"); return true; } /// Return true if the specified inline history ID /// indicates an inline history that includes the specified function. -static bool InlineHistoryIncludes(Function *F, int InlineHistoryID, - const SmallVectorImpl<std::pair<Function*, int> > &InlineHistory) { +static bool InlineHistoryIncludes( + Function *F, int InlineHistoryID, + const SmallVectorImpl<std::pair<Function *, int>> &InlineHistory) { while (InlineHistoryID != -1) { assert(unsigned(InlineHistoryID) < InlineHistory.size() && "Invalid inline history ID"); @@ -367,23 +412,32 @@ static bool InlineHistoryIncludes(Function *F, int InlineHistoryID, return false; } -bool Inliner::runOnSCC(CallGraphSCC &SCC) { +bool LegacyInlinerBase::doInitialization(CallGraph &CG) { + if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) + ImportedFunctionsStats.setModuleInfo(CG.getModule()); + return false; // No changes to CallGraph. +} + +bool LegacyInlinerBase::runOnSCC(CallGraphSCC &SCC) { if (skipSCC(SCC)) return false; return inlineCalls(SCC); } -bool Inliner::inlineCalls(CallGraphSCC &SCC) { - CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); - ACT = &getAnalysis<AssumptionCacheTracker>(); - PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(CG.getModule()); - auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - - SmallPtrSet<Function*, 8> SCCFunctions; +static bool +inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, + std::function<AssumptionCache &(Function &)> GetAssumptionCache, + ProfileSummaryInfo *PSI, TargetLibraryInfo &TLI, + bool InsertLifetime, + function_ref<InlineCost(CallSite CS)> GetInlineCost, + function_ref<AAResults &(Function &)> AARGetter, + ImportedFunctionsInliningStatistics &ImportedFunctionsStats) { + SmallPtrSet<Function *, 8> SCCFunctions; DEBUG(dbgs() << "Inliner visiting SCC:"); for (CallGraphNode *Node : SCC) { Function *F = Node->getFunction(); - if (F) SCCFunctions.insert(F); + if (F) + SCCFunctions.insert(F); DEBUG(dbgs() << " " << (F ? F->getName() : "INDIRECTNODE")); } @@ -391,17 +445,19 @@ bool Inliner::inlineCalls(CallGraphSCC &SCC) { // inline call sites in the original functions, not call sites that result // from inlining other functions. SmallVector<std::pair<CallSite, int>, 16> CallSites; - + // When inlining a callee produces new call sites, we want to keep track of // the fact that they were inlined from the callee. This allows us to avoid // infinite inlining in some obscure cases. To represent this, we use an // index into the InlineHistory vector. 
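// Self-contained sketch of the inline-history chain described above,
// mirroring InlineHistoryIncludes; the string-based types are stand-ins for
// the real Function pointers.
#include <string>
#include <utility>
#include <vector>

// Each entry is (function that was inlined, history ID that was active at
// the call site it was inlined through); -1 marks a top-level call site.
static bool historyIncludesSketch(
    const std::string &F, int ID,
    const std::vector<std::pair<std::string, int>> &History) {
  while (ID != -1) {
    if (History[ID].first == F)
      return true; // F already appears on this chain; inlining it again
                   // here could loop forever on recursive functions.
    ID = History[ID].second;
  }
  return false;
}
// Example: inlining B into A at a top-level site pushes {"B", -1} as entry 0,
// and every call discovered inside the inlined body is queued with ID 0. A
// call back to B found there is skipped, because
// historyIncludesSketch("B", 0, History) is true.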
- SmallVector<std::pair<Function*, int>, 8> InlineHistory; + SmallVector<std::pair<Function *, int>, 8> InlineHistory; for (CallGraphNode *Node : SCC) { Function *F = Node->getFunction(); - if (!F) continue; - + if (!F || F->isDeclaration()) + continue; + + OptimizationRemarkEmitter ORE(F); for (BasicBlock &BB : *F) for (Instruction &I : BB) { CallSite CS(cast<Value>(&I)); @@ -409,14 +465,21 @@ bool Inliner::inlineCalls(CallGraphSCC &SCC) { // never be inlined. if (!CS || isa<IntrinsicInst>(I)) continue; - + // If this is a direct call to an external function, we can never inline // it. If it is an indirect call, inlining may resolve it to be a // direct call, so we keep it. if (Function *Callee = CS.getCalledFunction()) - if (Callee->isDeclaration()) + if (Callee->isDeclaration()) { + using namespace ore; + ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "NoDefinition", &I) + << NV("Callee", Callee) << " will not be inlined into " + << NV("Caller", CS.getCaller()) + << " because its definition is unavailable" + << setIsVerbose()); continue; - + } + CallSites.push_back(std::make_pair(CS, -1)); } } @@ -435,9 +498,8 @@ bool Inliner::inlineCalls(CallGraphSCC &SCC) { if (SCCFunctions.count(F)) std::swap(CallSites[i--], CallSites[--FirstCallInSCC]); - InlinedArrayAllocasTy InlinedArrayAllocas; - InlineFunctionInfo InlineInfo(&CG, ACT); + InlineFunctionInfo InlineInfo(&CG, &GetAssumptionCache); // Now that we have all of the call sites, loop over them and inline them if // it looks profitable to do so. @@ -450,7 +512,7 @@ bool Inliner::inlineCalls(CallGraphSCC &SCC) { // CallSites may be modified inside so ranged for loop can not be used. for (unsigned CSi = 0; CSi != CallSites.size(); ++CSi) { CallSite CS = CallSites[CSi].first; - + Function *Caller = CS.getCaller(); Function *Callee = CS.getCalledFunction(); @@ -459,16 +521,17 @@ bool Inliner::inlineCalls(CallGraphSCC &SCC) { // size. This happens because IPSCCP propagates the result out of the // call and then we're left with the dead call. if (isInstructionTriviallyDead(CS.getInstruction(), &TLI)) { - DEBUG(dbgs() << " -> Deleting dead call: " - << *CS.getInstruction() << "\n"); + DEBUG(dbgs() << " -> Deleting dead call: " << *CS.getInstruction() + << "\n"); // Update the call graph by deleting the edge from Callee to Caller. CG[Caller]->removeCallEdgeFor(CS); CS.getInstruction()->eraseFromParent(); ++NumCallsDeleted; } else { // We can only inline direct calls to non-declarations. - if (!Callee || Callee->isDeclaration()) continue; - + if (!Callee || Callee->isDeclaration()) + continue; + // If this call site was obtained by inlining another function, verify // that the include path for the function did not include the callee // itself. If so, we'd be recursively inlining the same function, @@ -478,37 +541,42 @@ bool Inliner::inlineCalls(CallGraphSCC &SCC) { if (InlineHistoryID != -1 && InlineHistoryIncludes(Callee, InlineHistoryID, InlineHistory)) continue; - - LLVMContext &CallerCtx = Caller->getContext(); // Get DebugLoc to report. CS will be invalid after Inliner. DebugLoc DLoc = CS.getInstruction()->getDebugLoc(); + BasicBlock *Block = CS.getParent(); + // FIXME for new PM: because of the old PM we currently generate ORE and + // in turn BFI on demand. With the new PM, the ORE dependency should + // just become a regular analysis dependency. + OptimizationRemarkEmitter ORE(Caller); // If the policy determines that we should inline this function, // try to do so. 
- if (!shouldInline(CS)) { - emitOptimizationRemarkMissed(CallerCtx, DEBUG_TYPE, *Caller, DLoc, - Twine(Callee->getName() + - " will not be inlined into " + - Caller->getName())); + using namespace ore; + if (!shouldInline(CS, GetInlineCost, ORE)) { + ORE.emit( + OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) + << NV("Callee", Callee) << " will not be inlined into " + << NV("Caller", Caller)); continue; } // Attempt to inline the function. - if (!InlineCallIfPossible(*this, CS, InlineInfo, InlinedArrayAllocas, - InlineHistoryID, InsertLifetime)) { - emitOptimizationRemarkMissed(CallerCtx, DEBUG_TYPE, *Caller, DLoc, - Twine(Callee->getName() + - " will not be inlined into " + - Caller->getName())); + if (!InlineCallIfPossible(CS, InlineInfo, InlinedArrayAllocas, + InlineHistoryID, InsertLifetime, AARGetter, + ImportedFunctionsStats)) { + ORE.emit( + OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) + << NV("Callee", Callee) << " will not be inlined into " + << NV("Caller", Caller)); continue; } ++NumInlined; // Report the inline decision. - emitOptimizationRemark( - CallerCtx, DEBUG_TYPE, *Caller, DLoc, - Twine(Callee->getName() + " inlined into " + Caller->getName())); + ORE.emit(OptimizationRemark(DEBUG_TYPE, "Inlined", DLoc, Block) + << NV("Callee", Callee) << " inlined into " + << NV("Caller", Caller)); // If inlining this function gave us any new call sites, throw them // onto our worklist to process. They are useful inline candidates. @@ -522,30 +590,30 @@ bool Inliner::inlineCalls(CallGraphSCC &SCC) { CallSites.push_back(std::make_pair(CallSite(Ptr), NewHistoryID)); } } - + // If we inlined or deleted the last possible call site to the function, // delete the function body now. if (Callee && Callee->use_empty() && Callee->hasLocalLinkage() && // TODO: Can remove if in SCC now. !SCCFunctions.count(Callee) && - + // The function may be apparently dead, but if there are indirect // callgraph references to the node, we cannot delete it yet, this // could invalidate the CGSCC iterator. CG[Callee]->getNumReferences() == 0) { - DEBUG(dbgs() << " -> Deleting dead function: " - << Callee->getName() << "\n"); + DEBUG(dbgs() << " -> Deleting dead function: " << Callee->getName() + << "\n"); CallGraphNode *CalleeNode = CG[Callee]; // Remove any call graph edges from the callee to its callees. CalleeNode->removeAllCalledFunctions(); - + // Removing the node for callee from the call graph and delete it. delete CG.removeFunctionFromModule(CalleeNode); ++NumDeleted; } - // Remove this call site from the list. If possible, use + // Remove this call site from the list. If possible, use // swap/pop_back for efficiency, but do not use it if doing so would // move a call site to a function in this SCC before the // 'FirstCallInSCC' barrier. @@ -553,7 +621,7 @@ bool Inliner::inlineCalls(CallGraphSCC &SCC) { CallSites[CSi] = CallSites.back(); CallSites.pop_back(); } else { - CallSites.erase(CallSites.begin()+CSi); + CallSites.erase(CallSites.begin() + CSi); } --CSi; @@ -565,17 +633,43 @@ bool Inliner::inlineCalls(CallGraphSCC &SCC) { return Changed; } +bool LegacyInlinerBase::inlineCalls(CallGraphSCC &SCC) { + CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + ACT = &getAnalysis<AssumptionCacheTracker>(); + PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + // We compute dedicated AA results for each function in the SCC as needed. 
We + // use a lambda referencing external objects so that they live long enough to + // be queried, but we re-use them each time. + Optional<BasicAAResult> BAR; + Optional<AAResults> AAR; + auto AARGetter = [&](Function &F) -> AAResults & { + BAR.emplace(createLegacyPMBasicAAResult(*this, F)); + AAR.emplace(createLegacyPMAAResults(*this, F, *BAR)); + return *AAR; + }; + auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { + return ACT->getAssumptionCache(F); + }; + return inlineCallsImpl(SCC, CG, GetAssumptionCache, PSI, TLI, InsertLifetime, + [this](CallSite CS) { return getInlineCost(CS); }, + AARGetter, ImportedFunctionsStats); +} + /// Remove now-dead linkonce functions at the end of /// processing to avoid breaking the SCC traversal. -bool Inliner::doFinalization(CallGraph &CG) { +bool LegacyInlinerBase::doFinalization(CallGraph &CG) { + if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) + ImportedFunctionsStats.dump(InlinerFunctionImportStats == + InlinerFunctionImportStatsOpts::Verbose); return removeDeadFunctions(CG); } /// Remove dead functions that are not included in DNR (Do Not Remove) list. -bool Inliner::removeDeadFunctions(CallGraph &CG, bool AlwaysInlineOnly) { - SmallVector<CallGraphNode*, 16> FunctionsToRemove; - SmallVector<CallGraphNode *, 16> DeadFunctionsInComdats; - SmallDenseMap<const Comdat *, int, 16> ComdatEntriesAlive; +bool LegacyInlinerBase::removeDeadFunctions(CallGraph &CG, + bool AlwaysInlineOnly) { + SmallVector<CallGraphNode *, 16> FunctionsToRemove; + SmallVector<Function *, 16> DeadFunctionsInComdats; auto RemoveCGN = [&](CallGraphNode *CGN) { // Remove any call graph edges from the function to its callees. @@ -616,9 +710,8 @@ bool Inliner::removeDeadFunctions(CallGraph &CG, bool AlwaysInlineOnly) { // The inliner doesn't visit non-function entities which are in COMDAT // groups so it is unsafe to do so *unless* the linkage is local. if (!F->hasLocalLinkage()) { - if (const Comdat *C = F->getComdat()) { - --ComdatEntriesAlive[C]; - DeadFunctionsInComdats.push_back(CGN); + if (F->hasComdat()) { + DeadFunctionsInComdats.push_back(F); continue; } } @@ -626,32 +719,11 @@ bool Inliner::removeDeadFunctions(CallGraph &CG, bool AlwaysInlineOnly) { RemoveCGN(CGN); } if (!DeadFunctionsInComdats.empty()) { - // Count up all the entities in COMDAT groups - auto ComdatGroupReferenced = [&](const Comdat *C) { - auto I = ComdatEntriesAlive.find(C); - if (I != ComdatEntriesAlive.end()) - ++(I->getSecond()); - }; - for (const Function &F : CG.getModule()) - if (const Comdat *C = F.getComdat()) - ComdatGroupReferenced(C); - for (const GlobalVariable &GV : CG.getModule().globals()) - if (const Comdat *C = GV.getComdat()) - ComdatGroupReferenced(C); - for (const GlobalAlias &GA : CG.getModule().aliases()) - if (const Comdat *C = GA.getComdat()) - ComdatGroupReferenced(C); - for (CallGraphNode *CGN : DeadFunctionsInComdats) { - Function *F = CGN->getFunction(); - const Comdat *C = F->getComdat(); - int NumAlive = ComdatEntriesAlive[C]; - // We can remove functions in a COMDAT group if the entire group is dead. - assert(NumAlive >= 0); - if (NumAlive > 0) - continue; - - RemoveCGN(CGN); - } + // Filter out the functions whose comdats remain alive. + filterDeadComdatFunctions(CG.getModule(), DeadFunctionsInComdats); + // Remove the rest. 
+ for (Function *F : DeadFunctionsInComdats) + RemoveCGN(CG[F]); } if (FunctionsToRemove.empty()) @@ -665,12 +737,201 @@ bool Inliner::removeDeadFunctions(CallGraph &CG, bool AlwaysInlineOnly) { // here to do this, it doesn't matter which order the functions are deleted // in. array_pod_sort(FunctionsToRemove.begin(), FunctionsToRemove.end()); - FunctionsToRemove.erase(std::unique(FunctionsToRemove.begin(), - FunctionsToRemove.end()), - FunctionsToRemove.end()); + FunctionsToRemove.erase( + std::unique(FunctionsToRemove.begin(), FunctionsToRemove.end()), + FunctionsToRemove.end()); for (CallGraphNode *CGN : FunctionsToRemove) { delete CG.removeFunctionFromModule(CGN); ++NumDeleted; } return true; } + +PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, + CGSCCAnalysisManager &AM, LazyCallGraph &CG, + CGSCCUpdateResult &UR) { + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerCGSCCProxy>(InitialC, CG) + .getManager(); + const ModuleAnalysisManager &MAM = + AM.getResult<ModuleAnalysisManagerCGSCCProxy>(InitialC, CG).getManager(); + bool Changed = false; + + assert(InitialC.size() > 0 && "Cannot handle an empty SCC!"); + Module &M = *InitialC.begin()->getFunction().getParent(); + ProfileSummaryInfo *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(M); + + std::function<AssumptionCache &(Function &)> GetAssumptionCache = + [&](Function &F) -> AssumptionCache & { + return FAM.getResult<AssumptionAnalysis>(F); + }; + + // Setup the data structure used to plumb customization into the + // `InlineFunction` routine. + InlineFunctionInfo IFI(/*cg=*/nullptr, &GetAssumptionCache); + + auto GetInlineCost = [&](CallSite CS) { + Function &Callee = *CS.getCalledFunction(); + auto &CalleeTTI = FAM.getResult<TargetIRAnalysis>(Callee); + return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, PSI); + }; + + // We use a worklist of nodes to process so that we can handle if the SCC + // structure changes and some nodes are no longer part of the current SCC. We + // also need to use an updatable pointer for the SCC as a consequence. + SmallVector<LazyCallGraph::Node *, 16> Nodes; + for (auto &N : InitialC) + Nodes.push_back(&N); + auto *C = &InitialC; + auto *RC = &C->getOuterRefSCC(); + + // We also use a secondary worklist of call sites within a particular node to + // allow quickly continuing to inline through newly inlined call sites where + // possible. + SmallVector<std::pair<CallSite, int>, 16> Calls; + + // When inlining a callee produces new call sites, we want to keep track of + // the fact that they were inlined from the callee. This allows us to avoid + // infinite inlining in some obscure cases. To represent this, we use an + // index into the InlineHistory vector. + SmallVector<std::pair<Function *, int>, 16> InlineHistory; + + // Track a set vector of inlined callees so that we can augment the caller + // with all of their edges in the call graph before pruning out the ones that + // got simplified away. + SmallSetVector<Function *, 4> InlinedCallees; + + // Track the dead functions to delete once finished with inlining calls. We + // defer deleting these to make it easier to handle the call graph updates. + SmallVector<Function *, 4> DeadFunctions; + + do { + auto &N = *Nodes.pop_back_val(); + if (CG.lookupSCC(N) != C) + continue; + Function &F = N.getFunction(); + if (F.hasFnAttribute(Attribute::OptimizeNone)) + continue; + + // Get the remarks emission analysis for the caller. 
+ auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + + // We want to generally process call sites top-down in order for + // simplifications stemming from replacing the call with the returned value + // after inlining to be visible to subsequent inlining decisions. So we + // walk the function backwards and then process the back of the vector. + // FIXME: Using reverse is a really bad way to do this. Instead we should + // do an actual PO walk of the function body. + for (Instruction &I : reverse(instructions(F))) + if (auto CS = CallSite(&I)) + if (Function *Callee = CS.getCalledFunction()) + if (!Callee->isDeclaration()) + Calls.push_back({CS, -1}); + + bool DidInline = false; + while (!Calls.empty()) { + int InlineHistoryID; + CallSite CS; + std::tie(CS, InlineHistoryID) = Calls.pop_back_val(); + Function &Callee = *CS.getCalledFunction(); + + if (InlineHistoryID != -1 && + InlineHistoryIncludes(&Callee, InlineHistoryID, InlineHistory)) + continue; + + // Check whether we want to inline this callsite. + if (!shouldInline(CS, GetInlineCost, ORE)) + continue; + + if (!InlineFunction(CS, IFI)) + continue; + DidInline = true; + InlinedCallees.insert(&Callee); + + // Add any new callsites to defined functions to the worklist. + if (!IFI.InlinedCallSites.empty()) { + int NewHistoryID = InlineHistory.size(); + InlineHistory.push_back({&Callee, InlineHistoryID}); + for (CallSite &CS : reverse(IFI.InlinedCallSites)) + if (Function *NewCallee = CS.getCalledFunction()) + if (!NewCallee->isDeclaration()) + Calls.push_back({CS, NewHistoryID}); + } + + // Merge the attributes based on the inlining. + AttributeFuncs::mergeAttributesForInlining(F, Callee); + + // For local functions, check whether this makes the callee trivially + // dead. In that case, we can drop the body of the function eagerly + // which may reduce the number of callers of other functions to one, + // changing inline cost thresholds. + if (Callee.hasLocalLinkage()) { + // To check this we also need to nuke any dead constant uses (perhaps + // made dead by this operation on other functions). + Callee.removeDeadConstantUsers(); + if (Callee.use_empty()) { + // Clear the body and queue the function itself for deletion when we + // finish inlining and call graph updates. + // Note that after this point, it is an error to do anything other + // than use the callee's address or delete it. + Callee.dropAllReferences(); + assert(find(DeadFunctions, &Callee) == DeadFunctions.end() && + "Cannot put cause a function to become dead twice!"); + DeadFunctions.push_back(&Callee); + } + } + } + + if (!DidInline) + continue; + Changed = true; + + // Add all the inlined callees' edges as ref edges to the caller. These are + // by definition trivial edges as we always have *some* transitive ref edge + // chain. While in some cases these edges are direct calls inside the + // callee, they have to be modeled in the inliner as reference edges as + // there may be a reference edge anywhere along the chain from the current + // caller to the callee that causes the whole thing to appear like + // a (transitive) reference edge that will require promotion to a call edge + // below. + for (Function *InlinedCallee : InlinedCallees) { + LazyCallGraph::Node &CalleeN = *CG.lookup(*InlinedCallee); + for (LazyCallGraph::Edge &E : CalleeN) + RC->insertTrivialRefEdge(N, *E.getNode()); + } + InlinedCallees.clear(); + + // At this point, since we have made changes we have at least removed + // a call instruction. 
However, in the process we do some incremental + // simplification of the surrounding code. This simplification can + // essentially do all of the same things as a function pass and we can + // re-use the exact same logic for updating the call graph to reflect the + // change.. + C = &updateCGAndAnalysisManagerForFunctionPass(CG, *C, N, AM, UR); + RC = &C->getOuterRefSCC(); + } while (!Nodes.empty()); + + // Now that we've finished inlining all of the calls across this SCC, delete + // all of the trivially dead functions, updating the call graph and the CGSCC + // pass manager in the process. + // + // Note that this walks a pointer set which has non-deterministic order but + // that is OK as all we do is delete things and add pointers to unordered + // sets. + for (Function *DeadF : DeadFunctions) { + // Get the necessary information out of the call graph and nuke the + // function there. + auto &DeadC = *CG.lookupSCC(*CG.lookup(*DeadF)); + auto &DeadRC = DeadC.getOuterRefSCC(); + CG.removeDeadFunction(*DeadF); + + // Mark the relevant parts of the call graph as invalid so we don't visit + // them. + UR.InvalidatedSCCs.insert(&DeadC); + UR.InvalidatedRefSCCs.insert(&DeadRC); + + // And delete the actual function from the module. + M.getFunctionList().erase(DeadF); + } + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Transforms/IPO/Internalize.cpp b/contrib/llvm/lib/Transforms/IPO/Internalize.cpp index 8c5c6f7..26db146 100644 --- a/contrib/llvm/lib/Transforms/IPO/Internalize.cpp +++ b/contrib/llvm/lib/Transforms/IPO/Internalize.cpp @@ -239,7 +239,7 @@ bool InternalizePass::internalizeModule(Module &M, CallGraph *CG) { InternalizePass::InternalizePass() : MustPreserveGV(PreserveAPIList()) {} -PreservedAnalyses InternalizePass::run(Module &M, AnalysisManager<Module> &AM) { +PreservedAnalyses InternalizePass::run(Module &M, ModuleAnalysisManager &AM) { if (!internalizeModule(M, AM.getCachedResult<CallGraphAnalysis>(M))) return PreservedAnalyses::all(); diff --git a/contrib/llvm/lib/Transforms/IPO/LowerTypeTests.cpp b/contrib/llvm/lib/Transforms/IPO/LowerTypeTests.cpp index 36089f0..deb7e81 100644 --- a/contrib/llvm/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/contrib/llvm/lib/Transforms/IPO/LowerTypeTests.cpp @@ -13,8 +13,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/LowerTypeTests.h" -#include "llvm/Transforms/IPO.h" #include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Constant.h" @@ -23,18 +23,27 @@ #include "llvm/IR/GlobalObject.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ModuleSummaryIndexYAML.h" #include "llvm/IR/Operator.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/TrailingObjects.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" using namespace llvm; using namespace lowertypetests; +using SummaryAction = LowerTypeTestsSummaryAction; + #define DEBUG_TYPE "lowertypetests" STATISTIC(ByteArraySizeBits, "Byte array size in bits"); @@ -48,6 +57,26 @@ static cl::opt<bool> 
AvoidReuse( cl::desc("Try to avoid reuse of byte array addresses using aliases"), cl::Hidden, cl::init(true)); +static cl::opt<SummaryAction> ClSummaryAction( + "lowertypetests-summary-action", + cl::desc("What to do with the summary when running this pass"), + cl::values(clEnumValN(SummaryAction::None, "none", "Do nothing"), + clEnumValN(SummaryAction::Import, "import", + "Import typeid resolutions from summary and globals"), + clEnumValN(SummaryAction::Export, "export", + "Export typeid resolutions to summary and globals")), + cl::Hidden); + +static cl::opt<std::string> ClReadSummary( + "lowertypetests-read-summary", + cl::desc("Read summary from given YAML file before running pass"), + cl::Hidden); + +static cl::opt<std::string> ClWriteSummary( + "lowertypetests-write-summary", + cl::desc("Write summary to given YAML file after running pass"), + cl::Hidden); + bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const { if (Offset < ByteOffset) return false; @@ -62,39 +91,6 @@ bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const { return Bits.count(BitOffset); } -bool BitSetInfo::containsValue( - const DataLayout &DL, - const DenseMap<GlobalObject *, uint64_t> &GlobalLayout, Value *V, - uint64_t COffset) const { - if (auto GV = dyn_cast<GlobalObject>(V)) { - auto I = GlobalLayout.find(GV); - if (I == GlobalLayout.end()) - return false; - return containsGlobalOffset(I->second + COffset); - } - - if (auto GEP = dyn_cast<GEPOperator>(V)) { - APInt APOffset(DL.getPointerSizeInBits(0), 0); - bool Result = GEP->accumulateConstantOffset(DL, APOffset); - if (!Result) - return false; - COffset += APOffset.getZExtValue(); - return containsValue(DL, GlobalLayout, GEP->getPointerOperand(), - COffset); - } - - if (auto Op = dyn_cast<Operator>(V)) { - if (Op->getOpcode() == Instruction::BitCast) - return containsValue(DL, GlobalLayout, Op->getOperand(0), COffset); - - if (Op->getOpcode() == Instruction::Select) - return containsValue(DL, GlobalLayout, Op->getOperand(1), COffset) && - containsValue(DL, GlobalLayout, Op->getOperand(2), COffset); - } - - return false; -} - void BitSetInfo::print(raw_ostream &OS) const { OS << "offset " << ByteOffset << " size " << BitSize << " align " << (1 << AlignLog2); @@ -201,59 +197,169 @@ struct ByteArrayInfo { std::set<uint64_t> Bits; uint64_t BitSize; GlobalVariable *ByteArray; - Constant *Mask; + GlobalVariable *MaskGlobal; }; -struct LowerTypeTests : public ModulePass { - static char ID; - LowerTypeTests() : ModulePass(ID) { - initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry()); +/// A POD-like structure that we use to store a global reference together with +/// its metadata types. In this pass we frequently need to query the set of +/// metadata types referenced by a global, which at the IR level is an expensive +/// operation involving a map lookup; this data structure helps to reduce the +/// number of times we need to do this lookup. 
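// Illustrative usage of the class defined below (variable names are
// hypothetical, not taken from this patch): the global and its type metadata
// are packed into a single BumpPtrAllocator allocation, so the relatively
// expensive getMetadata() query happens once per global:
//
//   BumpPtrAllocator Alloc;
//   SmallVector<MDNode *, 2> Types;
//   GO->getMetadata(LLVMContext::MD_type, Types);
//   GlobalTypeMember *GTM = GlobalTypeMember::create(Alloc, GO, Types);
//   for (MDNode *T : GTM->types())
//     ; // cheap ArrayRef walk, no further metadata lookups on GO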
+class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> { + GlobalObject *GO; + size_t NTypes; + + friend TrailingObjects; + size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; } + +public: + static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO, + ArrayRef<MDNode *> Types) { + auto *GTM = static_cast<GlobalTypeMember *>(Alloc.Allocate( + totalSizeToAlloc<MDNode *>(Types.size()), alignof(GlobalTypeMember))); + GTM->GO = GO; + GTM->NTypes = Types.size(); + std::uninitialized_copy(Types.begin(), Types.end(), + GTM->getTrailingObjects<MDNode *>()); + return GTM; } + GlobalObject *getGlobal() const { + return GO; + } + ArrayRef<MDNode *> types() const { + return makeArrayRef(getTrailingObjects<MDNode *>(), NTypes); + } +}; - Module *M; +class LowerTypeTestsModule { + Module &M; + + SummaryAction Action; + ModuleSummaryIndex *Summary; bool LinkerSubsectionsViaSymbols; Triple::ArchType Arch; + Triple::OSType OS; Triple::ObjectFormatType ObjectFormat; - IntegerType *Int1Ty; - IntegerType *Int8Ty; - IntegerType *Int32Ty; - Type *Int32PtrTy; - IntegerType *Int64Ty; - IntegerType *IntPtrTy; + + IntegerType *Int1Ty = Type::getInt1Ty(M.getContext()); + IntegerType *Int8Ty = Type::getInt8Ty(M.getContext()); + PointerType *Int8PtrTy = Type::getInt8PtrTy(M.getContext()); + IntegerType *Int32Ty = Type::getInt32Ty(M.getContext()); + PointerType *Int32PtrTy = PointerType::getUnqual(Int32Ty); + IntegerType *Int64Ty = Type::getInt64Ty(M.getContext()); + IntegerType *IntPtrTy = M.getDataLayout().getIntPtrType(M.getContext(), 0); + + // Indirect function call index assignment counter for WebAssembly + uint64_t IndirectIndex = 1; // Mapping from type identifiers to the call sites that test them. DenseMap<Metadata *, std::vector<CallInst *>> TypeTestCallSites; + /// This structure describes how to lower type tests for a particular type + /// identifier. It is either built directly from the global analysis (during + /// regular LTO or the regular LTO phase of ThinLTO), or indirectly using type + /// identifier summaries and external symbol references (in ThinLTO backends). + struct TypeIdLowering { + TypeTestResolution::Kind TheKind; + + /// All except Unsat: the start address within the combined global. + Constant *OffsetedGlobal; + + /// ByteArray, Inline, AllOnes: log2 of the required global alignment + /// relative to the start address. + Constant *AlignLog2; + + /// ByteArray, Inline, AllOnes: one less than the size of the memory region + /// covering members of this type identifier as a multiple of 2^AlignLog2. + Constant *SizeM1; + + /// ByteArray, Inline, AllOnes: range of SizeM1 expressed as a bit width. + unsigned SizeM1BitWidth; + + /// ByteArray: the byte array to test the address against. + Constant *TheByteArray; + + /// ByteArray: the bit mask to apply to bytes loaded from the byte array. + Constant *BitMask; + + /// Inline: the bit mask to test the address against. 
+ Constant *InlineBits; + }; + std::vector<ByteArrayInfo> ByteArrayInfos; + Function *WeakInitializerFn = nullptr; + BitSetInfo buildBitSet(Metadata *TypeId, - const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); + const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout); ByteArrayInfo *createByteArray(BitSetInfo &BSI); void allocateByteArrays(); - Value *createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, ByteArrayInfo *&BAI, + Value *createBitSetTest(IRBuilder<> &B, const TypeIdLowering &TIL, Value *BitOffset); - void - lowerTypeTestCalls(ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, - const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); - Value * - lowerBitSetCall(CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI, - Constant *CombinedGlobal, - const DenseMap<GlobalObject *, uint64_t> &GlobalLayout); + void lowerTypeTestCalls( + ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, + const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout); + Value *lowerTypeTestCall(Metadata *TypeId, CallInst *CI, + const TypeIdLowering &TIL); void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds, - ArrayRef<GlobalVariable *> Globals); + ArrayRef<GlobalTypeMember *> Globals); unsigned getJumpTableEntrySize(); Type *getJumpTableEntryType(); - Constant *createJumpTableEntry(GlobalObject *Src, Function *Dest, - unsigned Distance); + void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS, + SmallVectorImpl<Value *> &AsmArgs, Function *Dest); void verifyTypeMDNode(GlobalObject *GO, MDNode *Type); void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, - ArrayRef<Function *> Functions); + ArrayRef<GlobalTypeMember *> Functions); + void buildBitSetsFromFunctionsNative(ArrayRef<Metadata *> TypeIds, + ArrayRef<GlobalTypeMember *> Functions); + void buildBitSetsFromFunctionsWASM(ArrayRef<Metadata *> TypeIds, + ArrayRef<GlobalTypeMember *> Functions); void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds, - ArrayRef<GlobalObject *> Globals); + ArrayRef<GlobalTypeMember *> Globals); + + void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT); + void moveInitializerToModuleConstructor(GlobalVariable *GV); + void findGlobalVariableUsersOf(Constant *C, + SmallSetVector<GlobalVariable *, 8> &Out); + + void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions); + +public: + LowerTypeTestsModule(Module &M, SummaryAction Action, + ModuleSummaryIndex *Summary); bool lower(); - bool runOnModule(Module &M) override; + + // Lower the module using the action and summary passed as command line + // arguments. For testing purposes only. 
+ static bool runForTesting(Module &M); +}; + +struct LowerTypeTests : public ModulePass { + static char ID; + + bool UseCommandLine = false; + + SummaryAction Action; + ModuleSummaryIndex *Summary; + + LowerTypeTests() : ModulePass(ID), UseCommandLine(true) { + initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry()); + } + + LowerTypeTests(SummaryAction Action, ModuleSummaryIndex *Summary) + : ModulePass(ID), Action(Action), Summary(Summary) { + initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + if (UseCommandLine) + return LowerTypeTestsModule::runForTesting(M); + return LowerTypeTestsModule(M, Action, Summary).lower(); + } }; } // anonymous namespace @@ -262,27 +368,28 @@ INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false, false) char LowerTypeTests::ID = 0; -ModulePass *llvm::createLowerTypeTestsPass() { return new LowerTypeTests; } +ModulePass *llvm::createLowerTypeTestsPass(SummaryAction Action, + ModuleSummaryIndex *Summary) { + return new LowerTypeTests(Action, Summary); +} /// Build a bit set for TypeId using the object layouts in /// GlobalLayout. -BitSetInfo LowerTypeTests::buildBitSet( +BitSetInfo LowerTypeTestsModule::buildBitSet( Metadata *TypeId, - const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) { + const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) { BitSetBuilder BSB; // Compute the byte offset of each address associated with this type // identifier. - SmallVector<MDNode *, 2> Types; for (auto &GlobalAndOffset : GlobalLayout) { - Types.clear(); - GlobalAndOffset.first->getMetadata(LLVMContext::MD_type, Types); - for (MDNode *Type : Types) { + for (MDNode *Type : GlobalAndOffset.first->types()) { if (Type->getOperand(1) != TypeId) continue; uint64_t Offset = - cast<ConstantInt>(cast<ConstantAsMetadata>(Type->getOperand(0)) - ->getValue())->getZExtValue(); + cast<ConstantInt>( + cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) + ->getZExtValue(); BSB.addOffset(GlobalAndOffset.second + Offset); } } @@ -305,14 +412,14 @@ static Value *createMaskedBitTest(IRBuilder<> &B, Value *Bits, return B.CreateICmpNE(MaskedBits, ConstantInt::get(BitsType, 0)); } -ByteArrayInfo *LowerTypeTests::createByteArray(BitSetInfo &BSI) { +ByteArrayInfo *LowerTypeTestsModule::createByteArray(BitSetInfo &BSI) { // Create globals to stand in for byte arrays and masks. These never actually // get initialized, we RAUW and erase them later in allocateByteArrays() once // we know the offset and mask to use. 
auto ByteArrayGlobal = new GlobalVariable( - *M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr); - auto MaskGlobal = new GlobalVariable( - *M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr); + M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr); + auto MaskGlobal = new GlobalVariable(M, Int8Ty, /*isConstant=*/true, + GlobalValue::PrivateLinkage, nullptr); ByteArrayInfos.emplace_back(); ByteArrayInfo *BAI = &ByteArrayInfos.back(); @@ -320,11 +427,11 @@ ByteArrayInfo *LowerTypeTests::createByteArray(BitSetInfo &BSI) { BAI->Bits = BSI.Bits; BAI->BitSize = BSI.BitSize; BAI->ByteArray = ByteArrayGlobal; - BAI->Mask = ConstantExpr::getPtrToInt(MaskGlobal, Int8Ty); + BAI->MaskGlobal = MaskGlobal; return BAI; } -void LowerTypeTests::allocateByteArrays() { +void LowerTypeTestsModule::allocateByteArrays() { std::stable_sort(ByteArrayInfos.begin(), ByteArrayInfos.end(), [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) { return BAI1.BitSize > BAI2.BitSize; @@ -339,13 +446,14 @@ void LowerTypeTests::allocateByteArrays() { uint8_t Mask; BAB.allocate(BAI->Bits, BAI->BitSize, ByteArrayOffsets[I], Mask); - BAI->Mask->replaceAllUsesWith(ConstantInt::get(Int8Ty, Mask)); - cast<GlobalVariable>(BAI->Mask->getOperand(0))->eraseFromParent(); + BAI->MaskGlobal->replaceAllUsesWith( + ConstantExpr::getIntToPtr(ConstantInt::get(Int8Ty, Mask), Int8PtrTy)); + BAI->MaskGlobal->eraseFromParent(); } - Constant *ByteArrayConst = ConstantDataArray::get(M->getContext(), BAB.Bytes); + Constant *ByteArrayConst = ConstantDataArray::get(M.getContext(), BAB.Bytes); auto ByteArray = - new GlobalVariable(*M, ByteArrayConst->getType(), /*isConstant=*/true, + new GlobalVariable(M, ByteArrayConst->getType(), /*isConstant=*/true, GlobalValue::PrivateLinkage, ByteArrayConst); for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) { @@ -363,7 +471,7 @@ void LowerTypeTests::allocateByteArrays() { BAI->ByteArray->replaceAllUsesWith(GEP); } else { GlobalAlias *Alias = GlobalAlias::create( - Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, M); + Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, &M); BAI->ByteArray->replaceAllUsesWith(Alias); } BAI->ByteArray->eraseFromParent(); @@ -375,63 +483,84 @@ void LowerTypeTests::allocateByteArrays() { ByteArraySizeBytes = BAB.Bytes.size(); } -/// Build a test that bit BitOffset is set in BSI, where -/// BitSetGlobal is a global containing the bits in BSI. -Value *LowerTypeTests::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, - ByteArrayInfo *&BAI, Value *BitOffset) { - if (BSI.BitSize <= 64) { +/// Build a test that bit BitOffset is set in the type identifier that was +/// lowered to TIL, which must be either an Inline or a ByteArray. +Value *LowerTypeTestsModule::createBitSetTest(IRBuilder<> &B, + const TypeIdLowering &TIL, + Value *BitOffset) { + if (TIL.TheKind == TypeTestResolution::Inline) { // If the bit set is sufficiently small, we can avoid a load by bit testing // a constant. 
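    // Worked example with invented numbers: for BitSize = 5 and member
    // offsets {0, 2, 4}, the inline lowering builds the 32-bit constant
    // 0b10101. A scaled offset of 2 is accepted because bit 2 of that
    // constant is set, while an offset of 3 is rejected because bit 3 is
    // clear; no byte-array load is needed in this case.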
- IntegerType *BitsTy; - if (BSI.BitSize <= 32) - BitsTy = Int32Ty; - else - BitsTy = Int64Ty; - - uint64_t Bits = 0; - for (auto Bit : BSI.Bits) - Bits |= uint64_t(1) << Bit; - Constant *BitsConst = ConstantInt::get(BitsTy, Bits); - return createMaskedBitTest(B, BitsConst, BitOffset); + return createMaskedBitTest(B, TIL.InlineBits, BitOffset); } else { - if (!BAI) { - ++NumByteArraysCreated; - BAI = createByteArray(BSI); - } - - Constant *ByteArray = BAI->ByteArray; - Type *Ty = BAI->ByteArray->getValueType(); + Constant *ByteArray = TIL.TheByteArray; if (!LinkerSubsectionsViaSymbols && AvoidReuse) { // Each use of the byte array uses a different alias. This makes the // backend less likely to reuse previously computed byte array addresses, // improving the security of the CFI mechanism based on this pass. - ByteArray = GlobalAlias::create(BAI->ByteArray->getValueType(), 0, - GlobalValue::PrivateLinkage, "bits_use", - ByteArray, M); + ByteArray = GlobalAlias::create(Int8Ty, 0, GlobalValue::PrivateLinkage, + "bits_use", ByteArray, &M); } - Value *ByteAddr = B.CreateGEP(Ty, ByteArray, BitOffset); + Value *ByteAddr = B.CreateGEP(Int8Ty, ByteArray, BitOffset); Value *Byte = B.CreateLoad(ByteAddr); - Value *ByteAndMask = B.CreateAnd(Byte, BAI->Mask); + Value *ByteAndMask = + B.CreateAnd(Byte, ConstantExpr::getPtrToInt(TIL.BitMask, Int8Ty)); return B.CreateICmpNE(ByteAndMask, ConstantInt::get(Int8Ty, 0)); } } +static bool isKnownTypeIdMember(Metadata *TypeId, const DataLayout &DL, + Value *V, uint64_t COffset) { + if (auto GV = dyn_cast<GlobalObject>(V)) { + SmallVector<MDNode *, 2> Types; + GV->getMetadata(LLVMContext::MD_type, Types); + for (MDNode *Type : Types) { + if (Type->getOperand(1) != TypeId) + continue; + uint64_t Offset = + cast<ConstantInt>( + cast<ConstantAsMetadata>(Type->getOperand(0))->getValue()) + ->getZExtValue(); + if (COffset == Offset) + return true; + } + return false; + } + + if (auto GEP = dyn_cast<GEPOperator>(V)) { + APInt APOffset(DL.getPointerSizeInBits(0), 0); + bool Result = GEP->accumulateConstantOffset(DL, APOffset); + if (!Result) + return false; + COffset += APOffset.getZExtValue(); + return isKnownTypeIdMember(TypeId, DL, GEP->getPointerOperand(), COffset); + } + + if (auto Op = dyn_cast<Operator>(V)) { + if (Op->getOpcode() == Instruction::BitCast) + return isKnownTypeIdMember(TypeId, DL, Op->getOperand(0), COffset); + + if (Op->getOpcode() == Instruction::Select) + return isKnownTypeIdMember(TypeId, DL, Op->getOperand(1), COffset) && + isKnownTypeIdMember(TypeId, DL, Op->getOperand(2), COffset); + } + + return false; +} + /// Lower a llvm.type.test call to its implementation. Returns the value to /// replace the call with. 
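// Conceptual before/after for the lowering implemented below (shown as IR in
// comments; symbol names are invented). A call such as
//
//   %ok = call i1 @llvm.type.test(i8* %p, metadata !"_ZTS1A")
//
// becomes, for an AllOnes/ByteArray resolution, roughly
//
//   %off = sub i64 %p_as_int, %start_of_combined_global
//   %rot = or i64 (lshr %off, AlignLog2), (shl %off, PtrBits - AlignLog2)
//   %ok  = icmp ule i64 %rot, SizeM1
//
// plus a byte-array bit check when the set is not all-ones. With
// AlignLog2 = 3 and SizeM1 = 7, an in-range offset of 24 rotates to 3 and
// passes, while a misaligned offset of 25 rotates its low bit into the top
// bits and fails the compare.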
-Value *LowerTypeTests::lowerBitSetCall( - CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI, - Constant *CombinedGlobalIntAddr, - const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) { - Value *Ptr = CI->getArgOperand(0); - const DataLayout &DL = M->getDataLayout(); - - if (BSI.containsValue(DL, GlobalLayout, Ptr)) - return ConstantInt::getTrue(M->getContext()); +Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI, + const TypeIdLowering &TIL) { + if (TIL.TheKind == TypeTestResolution::Unsat) + return ConstantInt::getFalse(M.getContext()); - Constant *OffsetedGlobalAsInt = ConstantExpr::getAdd( - CombinedGlobalIntAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset)); + Value *Ptr = CI->getArgOperand(0); + const DataLayout &DL = M.getDataLayout(); + if (isKnownTypeIdMember(TypeId, DL, Ptr, 0)) + return ConstantInt::getTrue(M.getContext()); BasicBlock *InitialBB = CI->getParent(); @@ -439,36 +568,36 @@ Value *LowerTypeTests::lowerBitSetCall( Value *PtrAsInt = B.CreatePtrToInt(Ptr, IntPtrTy); - if (BSI.isSingleOffset()) + Constant *OffsetedGlobalAsInt = + ConstantExpr::getPtrToInt(TIL.OffsetedGlobal, IntPtrTy); + if (TIL.TheKind == TypeTestResolution::Single) return B.CreateICmpEQ(PtrAsInt, OffsetedGlobalAsInt); Value *PtrOffset = B.CreateSub(PtrAsInt, OffsetedGlobalAsInt); - Value *BitOffset; - if (BSI.AlignLog2 == 0) { - BitOffset = PtrOffset; - } else { - // We need to check that the offset both falls within our range and is - // suitably aligned. We can check both properties at the same time by - // performing a right rotate by log2(alignment) followed by an integer - // comparison against the bitset size. The rotate will move the lower - // order bits that need to be zero into the higher order bits of the - // result, causing the comparison to fail if they are nonzero. The rotate - // also conveniently gives us a bit offset to use during the load from - // the bitset. - Value *OffsetSHR = - B.CreateLShr(PtrOffset, ConstantInt::get(IntPtrTy, BSI.AlignLog2)); - Value *OffsetSHL = B.CreateShl( - PtrOffset, - ConstantInt::get(IntPtrTy, DL.getPointerSizeInBits(0) - BSI.AlignLog2)); - BitOffset = B.CreateOr(OffsetSHR, OffsetSHL); - } - - Constant *BitSizeConst = ConstantInt::get(IntPtrTy, BSI.BitSize); - Value *OffsetInRange = B.CreateICmpULT(BitOffset, BitSizeConst); + // We need to check that the offset both falls within our range and is + // suitably aligned. We can check both properties at the same time by + // performing a right rotate by log2(alignment) followed by an integer + // comparison against the bitset size. The rotate will move the lower + // order bits that need to be zero into the higher order bits of the + // result, causing the comparison to fail if they are nonzero. The rotate + // also conveniently gives us a bit offset to use during the load from + // the bitset. + Value *OffsetSHR = + B.CreateLShr(PtrOffset, ConstantExpr::getZExt(TIL.AlignLog2, IntPtrTy)); + Value *OffsetSHL = B.CreateShl( + PtrOffset, ConstantExpr::getZExt( + ConstantExpr::getSub( + ConstantInt::get(Int8Ty, DL.getPointerSizeInBits(0)), + TIL.AlignLog2), + IntPtrTy)); + Value *BitOffset = B.CreateOr(OffsetSHR, OffsetSHL); + + Constant *BitSizeConst = ConstantExpr::getZExt(TIL.SizeM1, IntPtrTy); + Value *OffsetInRange = B.CreateICmpULE(BitOffset, BitSizeConst); // If the bit set is all ones, testing against it is unnecessary. 
- if (BSI.isAllOnes()) + if (TIL.TheKind == TypeTestResolution::AllOnes) return OffsetInRange; TerminatorInst *Term = SplitBlockAndInsertIfThen(OffsetInRange, CI, false); @@ -476,7 +605,7 @@ Value *LowerTypeTests::lowerBitSetCall( // Now that we know that the offset is in range and aligned, load the // appropriate bit from the bitset. - Value *Bit = createBitSetTest(ThenB, BSI, BAI, BitOffset); + Value *Bit = createBitSetTest(ThenB, TIL, BitOffset); // The value we want is 0 if we came directly from the initial block // (having failed the range or alignment checks), or the loaded bit if @@ -490,17 +619,18 @@ Value *LowerTypeTests::lowerBitSetCall( /// Given a disjoint set of type identifiers and globals, lay out the globals, /// build the bit sets and lower the llvm.type.test calls. -void LowerTypeTests::buildBitSetsFromGlobalVariables( - ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalVariable *> Globals) { +void LowerTypeTestsModule::buildBitSetsFromGlobalVariables( + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals) { // Build a new global with the combined contents of the referenced globals. // This global is a struct whose even-indexed elements contain the original // contents of the referenced globals and whose odd-indexed elements contain // any padding required to align the next element to the next power of 2. std::vector<Constant *> GlobalInits; - const DataLayout &DL = M->getDataLayout(); - for (GlobalVariable *G : Globals) { - GlobalInits.push_back(G->getInitializer()); - uint64_t InitSize = DL.getTypeAllocSize(G->getValueType()); + const DataLayout &DL = M.getDataLayout(); + for (GlobalTypeMember *G : Globals) { + GlobalVariable *GV = cast<GlobalVariable>(G->getGlobal()); + GlobalInits.push_back(GV->getInitializer()); + uint64_t InitSize = DL.getTypeAllocSize(GV->getValueType()); // Compute the amount of padding required. uint64_t Padding = NextPowerOf2(InitSize - 1) - InitSize; @@ -515,16 +645,16 @@ void LowerTypeTests::buildBitSetsFromGlobalVariables( } if (!GlobalInits.empty()) GlobalInits.pop_back(); - Constant *NewInit = ConstantStruct::getAnon(M->getContext(), GlobalInits); + Constant *NewInit = ConstantStruct::getAnon(M.getContext(), GlobalInits); auto *CombinedGlobal = - new GlobalVariable(*M, NewInit->getType(), /*isConstant=*/true, + new GlobalVariable(M, NewInit->getType(), /*isConstant=*/true, GlobalValue::PrivateLinkage, NewInit); StructType *NewTy = cast<StructType>(NewInit->getType()); const StructLayout *CombinedGlobalLayout = DL.getStructLayout(NewTy); // Compute the offsets of the original globals within the new global. - DenseMap<GlobalObject *, uint64_t> GlobalLayout; + DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout; for (unsigned I = 0; I != Globals.size(); ++I) // Multiply by 2 to account for padding elements. GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2); @@ -535,31 +665,32 @@ void LowerTypeTests::buildBitSetsFromGlobalVariables( // global from which we built the combined global, and replace references // to the original globals with references to the aliases. for (unsigned I = 0; I != Globals.size(); ++I) { + GlobalVariable *GV = cast<GlobalVariable>(Globals[I]->getGlobal()); + // Multiply by 2 to account for padding elements. 
Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0), ConstantInt::get(Int32Ty, I * 2)}; Constant *CombinedGlobalElemPtr = ConstantExpr::getGetElementPtr( NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs); if (LinkerSubsectionsViaSymbols) { - Globals[I]->replaceAllUsesWith(CombinedGlobalElemPtr); + GV->replaceAllUsesWith(CombinedGlobalElemPtr); } else { - assert(Globals[I]->getType()->getAddressSpace() == 0); + assert(GV->getType()->getAddressSpace() == 0); GlobalAlias *GAlias = GlobalAlias::create(NewTy->getElementType(I * 2), 0, - Globals[I]->getLinkage(), "", - CombinedGlobalElemPtr, M); - GAlias->setVisibility(Globals[I]->getVisibility()); - GAlias->takeName(Globals[I]); - Globals[I]->replaceAllUsesWith(GAlias); + GV->getLinkage(), "", + CombinedGlobalElemPtr, &M); + GAlias->setVisibility(GV->getVisibility()); + GAlias->takeName(GV); + GV->replaceAllUsesWith(GAlias); } - Globals[I]->eraseFromParent(); + GV->eraseFromParent(); } } -void LowerTypeTests::lowerTypeTestCalls( +void LowerTypeTestsModule::lowerTypeTestCalls( ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr, - const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) { - Constant *CombinedGlobalIntAddr = - ConstantExpr::getPtrToInt(CombinedGlobalAddr, IntPtrTy); + const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) { + CombinedGlobalAddr = ConstantExpr::getBitCast(CombinedGlobalAddr, Int8PtrTy); // For each type identifier in this disjoint set... for (Metadata *TypeId : TypeIds) { @@ -573,23 +704,52 @@ void LowerTypeTests::lowerTypeTestCalls( BSI.print(dbgs()); }); - ByteArrayInfo *BAI = nullptr; + TypeIdLowering TIL; + TIL.OffsetedGlobal = ConstantExpr::getGetElementPtr( + Int8Ty, CombinedGlobalAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset)), + TIL.AlignLog2 = ConstantInt::get(Int8Ty, BSI.AlignLog2); + if (BSI.isAllOnes()) { + TIL.TheKind = (BSI.BitSize == 1) ? TypeTestResolution::Single + : TypeTestResolution::AllOnes; + TIL.SizeM1BitWidth = (BSI.BitSize <= 128) ? 7 : 32; + TIL.SizeM1 = ConstantInt::get((BSI.BitSize <= 128) ? Int8Ty : Int32Ty, + BSI.BitSize - 1); + } else if (BSI.BitSize <= 64) { + TIL.TheKind = TypeTestResolution::Inline; + TIL.SizeM1BitWidth = (BSI.BitSize <= 32) ? 5 : 6; + TIL.SizeM1 = ConstantInt::get(Int8Ty, BSI.BitSize - 1); + uint64_t InlineBits = 0; + for (auto Bit : BSI.Bits) + InlineBits |= uint64_t(1) << Bit; + if (InlineBits == 0) + TIL.TheKind = TypeTestResolution::Unsat; + else + TIL.InlineBits = ConstantInt::get( + (BSI.BitSize <= 32) ? Int32Ty : Int64Ty, InlineBits); + } else { + TIL.TheKind = TypeTestResolution::ByteArray; + TIL.SizeM1BitWidth = (BSI.BitSize <= 128) ? 7 : 32; + TIL.SizeM1 = ConstantInt::get((BSI.BitSize <= 128) ? Int8Ty : Int32Ty, + BSI.BitSize - 1); + ++NumByteArraysCreated; + ByteArrayInfo *BAI = createByteArray(BSI); + TIL.TheByteArray = BAI->ByteArray; + TIL.BitMask = BAI->MaskGlobal; + } // Lower each call to llvm.type.test for this type identifier. 
for (CallInst *CI : TypeTestCallSites[TypeId]) { ++NumTypeTestCallsLowered; - Value *Lowered = - lowerBitSetCall(CI, BSI, BAI, CombinedGlobalIntAddr, GlobalLayout); + Value *Lowered = lowerTypeTestCall(TypeId, CI, TIL); CI->replaceAllUsesWith(Lowered); CI->eraseFromParent(); } } } -void LowerTypeTests::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { +void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { if (Type->getNumOperands() != 2) - report_fatal_error( - "All operands of type metadata must have 2 elements"); + report_fatal_error("All operands of type metadata must have 2 elements"); if (GO->isThreadLocal()) report_fatal_error("Bit set element may not be thread-local"); @@ -610,60 +770,172 @@ void LowerTypeTests::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) { } static const unsigned kX86JumpTableEntrySize = 8; +static const unsigned kARMJumpTableEntrySize = 4; + +unsigned LowerTypeTestsModule::getJumpTableEntrySize() { + switch (Arch) { + case Triple::x86: + case Triple::x86_64: + return kX86JumpTableEntrySize; + case Triple::arm: + case Triple::thumb: + case Triple::aarch64: + return kARMJumpTableEntrySize; + default: + report_fatal_error("Unsupported architecture for jump tables"); + } +} -unsigned LowerTypeTests::getJumpTableEntrySize() { - if (Arch != Triple::x86 && Arch != Triple::x86_64) +// Create a jump table entry for the target. This consists of an instruction +// sequence containing a relative branch to Dest. Appends inline asm text, +// constraints and arguments to AsmOS, ConstraintOS and AsmArgs. +void LowerTypeTestsModule::createJumpTableEntry( + raw_ostream &AsmOS, raw_ostream &ConstraintOS, + SmallVectorImpl<Value *> &AsmArgs, Function *Dest) { + unsigned ArgIndex = AsmArgs.size(); + + if (Arch == Triple::x86 || Arch == Triple::x86_64) { + AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n"; + AsmOS << "int3\nint3\nint3\n"; + } else if (Arch == Triple::arm || Arch == Triple::aarch64) { + AsmOS << "b $" << ArgIndex << "\n"; + } else if (Arch == Triple::thumb) { + AsmOS << "b.w $" << ArgIndex << "\n"; + } else { report_fatal_error("Unsupported architecture for jump tables"); + } - return kX86JumpTableEntrySize; + ConstraintOS << (ArgIndex > 0 ? ",s" : "s"); + AsmArgs.push_back(Dest); } -// Create a constant representing a jump table entry for the target. This -// consists of an instruction sequence containing a relative branch to Dest. The -// constant will be laid out at address Src+(Len*Distance) where Len is the -// target-specific jump table entry size. -Constant *LowerTypeTests::createJumpTableEntry(GlobalObject *Src, - Function *Dest, - unsigned Distance) { - if (Arch != Triple::x86 && Arch != Triple::x86_64) - report_fatal_error("Unsupported architecture for jump tables"); +Type *LowerTypeTestsModule::getJumpTableEntryType() { + return ArrayType::get(Int8Ty, getJumpTableEntrySize()); +} - const unsigned kJmpPCRel32Code = 0xe9; - const unsigned kInt3Code = 0xcc; +/// Given a disjoint set of type identifiers and functions, build the bit sets +/// and lower the llvm.type.test calls, architecture dependently. 
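Editor's sketch: the per-architecture entry sizes and branch sequences above (8 bytes on x86 via jmp plus int3 padding, 4 bytes on ARM via a single branch) can be modelled without the pass. A small sketch that emits the same style of inline-asm text for a list of entries; the enum and helpers are illustrative names, not LLVM API.

#include <cstdio>
#include <string>

enum class Arch { X86, X86_64, ARM, Thumb, AArch64 };

// Entry sizes matching the constants in the hunk above.
unsigned jumpTableEntrySize(Arch A) {
  switch (A) {
  case Arch::X86:
  case Arch::X86_64: return 8;
  case Arch::ARM:
  case Arch::Thumb:
  case Arch::AArch64: return 4;
  }
  return 0;
}

// Appends one jump table entry referring to inline-asm argument ArgIndex.
void appendEntry(Arch A, unsigned ArgIndex, std::string &Asm) {
  char Buf[64];
  if (A == Arch::X86 || A == Arch::X86_64)
    snprintf(Buf, sizeof(Buf), "jmp ${%u:c}@plt\nint3\nint3\nint3\n", ArgIndex);
  else if (A == Arch::Thumb)
    snprintf(Buf, sizeof(Buf), "b.w $%u\n", ArgIndex);
  else
    snprintf(Buf, sizeof(Buf), "b $%u\n", ArgIndex);
  Asm += Buf;
}

int main() {
  std::string Asm;
  for (unsigned I = 0; I != 3; ++I)
    appendEntry(Arch::X86_64, I, Asm);
  printf("%s", Asm.c_str());  // Three entries of jumpTableEntrySize() bytes each.
}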
+void LowerTypeTestsModule::buildBitSetsFromFunctions( + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) { + if (Arch == Triple::x86 || Arch == Triple::x86_64 || Arch == Triple::arm || + Arch == Triple::thumb || Arch == Triple::aarch64) + buildBitSetsFromFunctionsNative(TypeIds, Functions); + else if (Arch == Triple::wasm32 || Arch == Triple::wasm64) + buildBitSetsFromFunctionsWASM(TypeIds, Functions); + else + report_fatal_error("Unsupported architecture for jump tables"); +} - ConstantInt *Jmp = ConstantInt::get(Int8Ty, kJmpPCRel32Code); +void LowerTypeTestsModule::moveInitializerToModuleConstructor( + GlobalVariable *GV) { + if (WeakInitializerFn == nullptr) { + WeakInitializerFn = Function::Create( + FunctionType::get(Type::getVoidTy(M.getContext()), + /* IsVarArg */ false), + GlobalValue::InternalLinkage, "__cfi_global_var_init", &M); + BasicBlock *BB = + BasicBlock::Create(M.getContext(), "entry", WeakInitializerFn); + ReturnInst::Create(M.getContext(), BB); + WeakInitializerFn->setSection( + ObjectFormat == Triple::MachO + ? "__TEXT,__StaticInit,regular,pure_instructions" + : ".text.startup"); + // This code is equivalent to relocation application, and should run at the + // earliest possible time (i.e. with the highest priority). + appendToGlobalCtors(M, WeakInitializerFn, /* Priority */ 0); + } - // Build a constant representing the displacement between the constant's - // address and Dest. This will resolve to a PC32 relocation referring to Dest. - Constant *DestInt = ConstantExpr::getPtrToInt(Dest, IntPtrTy); - Constant *SrcInt = ConstantExpr::getPtrToInt(Src, IntPtrTy); - Constant *Disp = ConstantExpr::getSub(DestInt, SrcInt); - ConstantInt *DispOffset = - ConstantInt::get(IntPtrTy, Distance * kX86JumpTableEntrySize + 5); - Constant *OffsetedDisp = ConstantExpr::getSub(Disp, DispOffset); - OffsetedDisp = ConstantExpr::getTruncOrBitCast(OffsetedDisp, Int32Ty); + IRBuilder<> IRB(WeakInitializerFn->getEntryBlock().getTerminator()); + GV->setConstant(false); + IRB.CreateAlignedStore(GV->getInitializer(), GV, GV->getAlignment()); + GV->setInitializer(Constant::getNullValue(GV->getValueType())); +} - ConstantInt *Int3 = ConstantInt::get(Int8Ty, kInt3Code); +void LowerTypeTestsModule::findGlobalVariableUsersOf( + Constant *C, SmallSetVector<GlobalVariable *, 8> &Out) { + for (auto *U : C->users()){ + if (auto *GV = dyn_cast<GlobalVariable>(U)) + Out.insert(GV); + else if (auto *C2 = dyn_cast<Constant>(U)) + findGlobalVariableUsersOf(C2, Out); + } +} - Constant *Fields[] = { - Jmp, OffsetedDisp, Int3, Int3, Int3, - }; - return ConstantStruct::getAnon(Fields, /*Packed=*/true); +// Replace all uses of F with (F ? JT : 0). +void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( + Function *F, Constant *JT) { + // The target expression can not appear in a constant initializer on most + // (all?) targets. Switch to a runtime initializer. + SmallSetVector<GlobalVariable *, 8> GlobalVarUsers; + findGlobalVariableUsersOf(F, GlobalVarUsers); + for (auto GV : GlobalVarUsers) + moveInitializerToModuleConstructor(GV); + + // Can not RAUW F with an expression that uses F. Replace with a temporary + // placeholder first. 
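Editor's sketch: findGlobalVariableUsersOf above walks the use graph recursively because a function may be referenced through arbitrarily nested constant expressions before it reaches a global variable initializer. The same shape over a toy node type (purely illustrative; in the pass only Constant users are followed):

#include <set>
#include <vector>

// Toy user graph: a node is either a "global variable" or a nested
// "constant expression" with further users.
struct Node {
  bool IsGlobalVar = false;
  std::vector<Node *> Users;  // Who uses this node.
};

// Collect every global-variable node reachable through chains of users,
// mirroring the recursion in findGlobalVariableUsersOf.
void findGlobalVarUsers(Node *N, std::set<Node *> &Out) {
  for (Node *U : N->Users) {
    if (U->IsGlobalVar)
      Out.insert(U);
    else
      findGlobalVarUsers(U, Out);  // Recurse through constant expressions.
  }
}

int main() {
  Node F, CE, GV;
  GV.IsGlobalVar = true;
  F.Users = {&CE};   // F is used by a constant expression...
  CE.Users = {&GV};  // ...which appears in a global variable initializer.
  std::set<Node *> Out;
  findGlobalVarUsers(&F, Out);  // Out == {&GV}
  return Out.count(&GV) ? 0 : 1;
}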
+ Function *PlaceholderFn = + Function::Create(cast<FunctionType>(F->getValueType()), + GlobalValue::ExternalWeakLinkage, "", &M); + F->replaceAllUsesWith(PlaceholderFn); + + Constant *Target = ConstantExpr::getSelect( + ConstantExpr::getICmp(CmpInst::ICMP_NE, F, + Constant::getNullValue(F->getType())), + JT, Constant::getNullValue(F->getType())); + PlaceholderFn->replaceAllUsesWith(Target); + PlaceholderFn->eraseFromParent(); } -Type *LowerTypeTests::getJumpTableEntryType() { - if (Arch != Triple::x86 && Arch != Triple::x86_64) - report_fatal_error("Unsupported architecture for jump tables"); +void LowerTypeTestsModule::createJumpTable( + Function *F, ArrayRef<GlobalTypeMember *> Functions) { + std::string AsmStr, ConstraintStr; + raw_string_ostream AsmOS(AsmStr), ConstraintOS(ConstraintStr); + SmallVector<Value *, 16> AsmArgs; + AsmArgs.reserve(Functions.size() * 2); - return StructType::get(M->getContext(), - {Int8Ty, Int32Ty, Int8Ty, Int8Ty, Int8Ty}, - /*Packed=*/true); + for (unsigned I = 0; I != Functions.size(); ++I) + createJumpTableEntry(AsmOS, ConstraintOS, AsmArgs, + cast<Function>(Functions[I]->getGlobal())); + + // Try to emit the jump table at the end of the text segment. + // Jump table must come after __cfi_check in the cross-dso mode. + // FIXME: this magic section name seems to do the trick. + F->setSection(ObjectFormat == Triple::MachO + ? "__TEXT,__text,regular,pure_instructions" + : ".text.cfi"); + // Align the whole table by entry size. + F->setAlignment(getJumpTableEntrySize()); + // Skip prologue. + // Disabled on win32 due to https://llvm.org/bugs/show_bug.cgi?id=28641#c3. + // Luckily, this function does not get any prologue even without the + // attribute. + if (OS != Triple::Win32) + F->addFnAttr(llvm::Attribute::Naked); + // Thumb jump table assembly needs Thumb2. The following attribute is added by + // Clang for -march=armv7. + if (Arch == Triple::thumb) + F->addFnAttr("target-cpu", "cortex-a8"); + + BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F); + IRBuilder<> IRB(BB); + + SmallVector<Type *, 16> ArgTypes; + ArgTypes.reserve(AsmArgs.size()); + for (const auto &Arg : AsmArgs) + ArgTypes.push_back(Arg->getType()); + InlineAsm *JumpTableAsm = + InlineAsm::get(FunctionType::get(IRB.getVoidTy(), ArgTypes, false), + AsmOS.str(), ConstraintOS.str(), + /*hasSideEffects=*/true); + + IRB.CreateCall(JumpTableAsm, AsmArgs); + IRB.CreateUnreachable(); } /// Given a disjoint set of type identifiers and functions, build a jump table /// for the functions, build the bit sets and lower the llvm.type.test calls. -void LowerTypeTests::buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, - ArrayRef<Function *> Functions) { +void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) { // Unlike the global bitset builder, the function bitset builder cannot // re-arrange functions in a particular order and base its calculations on the // layout of the functions' entry points, as we have no idea how large a @@ -697,39 +969,35 @@ void LowerTypeTests::buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, // mov h, %ecx // ret // - // To create a jump table for these functions, we instruct the LLVM code - // generator to output a jump table in the .text section. This is done by - // representing the instructions in the jump table as an LLVM constant and - // placing them in a global variable in the .text section. 
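Editor's sketch: conceptually, replaceWeakDeclarationWithJumpTablePtr rewrites every use of a weak function F so the jump-table address is only taken when F actually resolved to a definition at link time, i.e. F ? JT : 0. In source terms the lowered reference behaves roughly like the expression below; weak_fn and jump_table_entry_for_weak_fn are hypothetical names, and the weak attribute is a GCC/Clang extension. This is a conceptual model, not generated code.

#include <cstdint>

// A weak declaration with no definition, and a stand-in for its table entry.
extern void weak_fn(void) __attribute__((weak));
void jump_table_entry_for_weak_fn(void) {}

int main() {
  // (F ? JT : 0): use the jump-table address only when the weak symbol
  // resolved to a real definition, otherwise keep the null pointer.
  void (*Target)(void) =
      (&weak_fn != nullptr) ? &jump_table_entry_for_weak_fn : nullptr;
  return Target == nullptr ? 0 : 1;  // 0 here: weak_fn has no definition.
}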
The end result will - // (conceptually) look like this: + // We output the jump table as module-level inline asm string. The end result + // will (conceptually) look like this: // - // f: - // jmp .Ltmp0 ; 5 bytes + // f = .cfi.jumptable + // g = .cfi.jumptable + 4 + // h = .cfi.jumptable + 8 + // .cfi.jumptable: + // jmp f.cfi ; 5 bytes // int3 ; 1 byte // int3 ; 1 byte // int3 ; 1 byte - // - // g: - // jmp .Ltmp1 ; 5 bytes + // jmp g.cfi ; 5 bytes // int3 ; 1 byte // int3 ; 1 byte // int3 ; 1 byte - // - // h: - // jmp .Ltmp2 ; 5 bytes + // jmp h.cfi ; 5 bytes // int3 ; 1 byte // int3 ; 1 byte // int3 ; 1 byte // - // .Ltmp0: + // f.cfi: // mov 0, %eax // ret // - // .Ltmp1: + // g.cfi: // mov 1, %eax // ret // - // .Ltmp2: + // h.cfi: // mov 2, %eax // ret // @@ -743,60 +1011,101 @@ void LowerTypeTests::buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, // normal case the check can be carried out using the same kind of simple // arithmetic that we normally use for globals. + // FIXME: find a better way to represent the jumptable in the IR. + assert(!Functions.empty()); // Build a simple layout based on the regular layout of jump tables. - DenseMap<GlobalObject *, uint64_t> GlobalLayout; + DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout; unsigned EntrySize = getJumpTableEntrySize(); for (unsigned I = 0; I != Functions.size(); ++I) GlobalLayout[Functions[I]] = I * EntrySize; - // Create a constant to hold the jump table. + Function *JumpTableFn = + Function::Create(FunctionType::get(Type::getVoidTy(M.getContext()), + /* IsVarArg */ false), + GlobalValue::PrivateLinkage, ".cfi.jumptable", &M); ArrayType *JumpTableType = ArrayType::get(getJumpTableEntryType(), Functions.size()); - auto JumpTable = new GlobalVariable(*M, JumpTableType, - /*isConstant=*/true, - GlobalValue::PrivateLinkage, nullptr); - JumpTable->setSection(ObjectFormat == Triple::MachO - ? "__TEXT,__text,regular,pure_instructions" - : ".text"); + auto JumpTable = + ConstantExpr::getPointerCast(JumpTableFn, JumpTableType->getPointerTo(0)); + lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout); // Build aliases pointing to offsets into the jump table, and replace // references to the original functions with references to the aliases. 
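Editor's sketch: the "simple arithmetic" mentioned above is a combined range-and-alignment check: subtract the table base, rotate right by the alignment, and compare against the member count. A misaligned pointer rotates its low bits into the high bits and fails the comparison. The sketch below is my reading of that scheme with made-up constants (entry size 8, four entries) and assumes 0 < AlignLog2 < 64; it is not the emitted IR.

#include <cstdint>
#include <cstdio>

bool typeTestSketch(uint64_t Addr, uint64_t Base, unsigned AlignLog2,
                    uint64_t SizeM1) {
  uint64_t Diff = Addr - Base;
  // ror64(Diff, AlignLog2): range and alignment checked in one compare.
  uint64_t Rotated = (Diff >> AlignLog2) | (Diff << (64 - AlignLog2));
  return Rotated <= SizeM1;
}

int main() {
  const uint64_t Base = 0x400000;  // Hypothetical .cfi.jumptable address.
  printf("%d\n", typeTestSketch(0x400010, Base, 3, 3));  // entry 2: passes
  printf("%d\n", typeTestSketch(0x400013, Base, 3, 3));  // misaligned: fails
  printf("%d\n", typeTestSketch(0x400040, Base, 3, 3));  // past the end: fails
}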
for (unsigned I = 0; I != Functions.size(); ++I) { + Function *F = cast<Function>(Functions[I]->getGlobal()); + Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast( - ConstantExpr::getGetElementPtr( + ConstantExpr::getInBoundsGetElementPtr( JumpTableType, JumpTable, ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0), ConstantInt::get(IntPtrTy, I)}), - Functions[I]->getType()); - if (LinkerSubsectionsViaSymbols || Functions[I]->isDeclarationForLinker()) { - Functions[I]->replaceAllUsesWith(CombinedGlobalElemPtr); + F->getType()); + if (LinkerSubsectionsViaSymbols || F->isDeclarationForLinker()) { + + if (F->isWeakForLinker()) + replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr); + else + F->replaceAllUsesWith(CombinedGlobalElemPtr); } else { - assert(Functions[I]->getType()->getAddressSpace() == 0); - GlobalAlias *GAlias = GlobalAlias::create(Functions[I]->getValueType(), 0, - Functions[I]->getLinkage(), "", - CombinedGlobalElemPtr, M); - GAlias->setVisibility(Functions[I]->getVisibility()); - GAlias->takeName(Functions[I]); - Functions[I]->replaceAllUsesWith(GAlias); + assert(F->getType()->getAddressSpace() == 0); + + GlobalAlias *FAlias = GlobalAlias::create(F->getValueType(), 0, + F->getLinkage(), "", + CombinedGlobalElemPtr, &M); + FAlias->setVisibility(F->getVisibility()); + FAlias->takeName(F); + if (FAlias->hasName()) + F->setName(FAlias->getName() + ".cfi"); + F->replaceAllUsesWith(FAlias); } - if (!Functions[I]->isDeclarationForLinker()) - Functions[I]->setLinkage(GlobalValue::PrivateLinkage); + if (!F->isDeclarationForLinker()) + F->setLinkage(GlobalValue::InternalLinkage); } - // Build and set the jump table's initializer. - std::vector<Constant *> JumpTableEntries; - for (unsigned I = 0; I != Functions.size(); ++I) - JumpTableEntries.push_back( - createJumpTableEntry(JumpTable, Functions[I], I)); - JumpTable->setInitializer( - ConstantArray::get(JumpTableType, JumpTableEntries)); + createJumpTable(JumpTableFn, Functions); +} + +/// Assign a dummy layout using an incrementing counter, tag each function +/// with its index represented as metadata, and lower each type test to an +/// integer range comparison. During generation of the indirect function call +/// table in the backend, it will assign the given indexes. +/// Note: Dynamic linking is not supported, as the WebAssembly ABI has not yet +/// been finalized. +void LowerTypeTestsModule::buildBitSetsFromFunctionsWASM( + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) { + assert(!Functions.empty()); + + // Build consecutive monotonic integer ranges for each call target set + DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout; + + for (GlobalTypeMember *GTM : Functions) { + Function *F = cast<Function>(GTM->getGlobal()); + + // Skip functions that are not address taken, to avoid bloating the table + if (!F->hasAddressTaken()) + continue; + + // Store metadata with the index for each function + MDNode *MD = MDNode::get(F->getContext(), + ArrayRef<Metadata *>(ConstantAsMetadata::get( + ConstantInt::get(Int64Ty, IndirectIndex)))); + F->setMetadata("wasm.index", MD); + + // Assign the counter value + GlobalLayout[GTM] = IndirectIndex++; + } + + // The indirect function table index space starts at zero, so pass a NULL + // pointer as the subtracted "jump table" offset. 
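Editor's sketch: on WebAssembly the code above builds no jump table at all; each address-taken function just receives a consecutive index, recorded as !wasm.index metadata, and a type test later reduces to a range check over those indices. The bookkeeping, with toy types and names in place of the pass's structures:

#include <cstdint>
#include <map>
#include <string>
#include <vector>

struct FnSketch {
  std::string Name;
  bool AddressTaken;
};

int main() {
  std::vector<FnSketch> Fns = {{"f", true}, {"g", false}, {"h", true}};
  std::map<std::string, uint64_t> Layout;  // Name -> indirect table index.
  uint64_t IndirectIndex = 0;
  for (const FnSketch &F : Fns) {
    if (!F.AddressTaken)
      continue;                          // Never called indirectly: no slot.
    Layout[F.Name] = IndirectIndex++;    // Becomes the !wasm.index metadata.
  }
  // A type test for the set {f, h} then behaves like a range check:
  //   id >= Layout["f"] && id - Layout["f"] <= 1
  return Layout.size() == 2 ? 0 : 1;
}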
+ lowerTypeTestCalls(TypeIds, ConstantPointerNull::get(Int32PtrTy), + GlobalLayout); } -void LowerTypeTests::buildBitSetsFromDisjointSet( - ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalObject *> Globals) { +void LowerTypeTestsModule::buildBitSetsFromDisjointSet( + ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals) { llvm::DenseMap<Metadata *, uint64_t> TypeIdIndices; for (unsigned I = 0; I != TypeIds.size(); ++I) TypeIdIndices[TypeIds[I]] = I; @@ -804,12 +1113,9 @@ void LowerTypeTests::buildBitSetsFromDisjointSet( // For each type identifier, build a set of indices that refer to members of // the type identifier. std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size()); - SmallVector<MDNode *, 2> Types; unsigned GlobalIndex = 0; - for (GlobalObject *GO : Globals) { - Types.clear(); - GO->getMetadata(LLVMContext::MD_type, Types); - for (MDNode *Type : Types) { + for (GlobalTypeMember *GTM : Globals) { + for (MDNode *Type : GTM->types()) { // Type = { offset, type identifier } unsigned TypeIdIndex = TypeIdIndices[Type->getOperand(1)]; TypeMembers[TypeIdIndex].insert(GlobalIndex); @@ -833,32 +1139,32 @@ void LowerTypeTests::buildBitSetsFromDisjointSet( GLB.addFragment(MemSet); // Build the bitsets from this disjoint set. - if (Globals.empty() || isa<GlobalVariable>(Globals[0])) { + if (Globals.empty() || isa<GlobalVariable>(Globals[0]->getGlobal())) { // Build a vector of global variables with the computed layout. - std::vector<GlobalVariable *> OrderedGVs(Globals.size()); + std::vector<GlobalTypeMember *> OrderedGVs(Globals.size()); auto OGI = OrderedGVs.begin(); for (auto &&F : GLB.Fragments) { for (auto &&Offset : F) { - auto GV = dyn_cast<GlobalVariable>(Globals[Offset]); + auto GV = dyn_cast<GlobalVariable>(Globals[Offset]->getGlobal()); if (!GV) report_fatal_error("Type identifier may not contain both global " "variables and functions"); - *OGI++ = GV; + *OGI++ = Globals[Offset]; } } buildBitSetsFromGlobalVariables(TypeIds, OrderedGVs); } else { // Build a vector of functions with the computed layout. - std::vector<Function *> OrderedFns(Globals.size()); + std::vector<GlobalTypeMember *> OrderedFns(Globals.size()); auto OFI = OrderedFns.begin(); for (auto &&F : GLB.Fragments) { for (auto &&Offset : F) { - auto Fn = dyn_cast<Function>(Globals[Offset]); + auto Fn = dyn_cast<Function>(Globals[Offset]->getGlobal()); if (!Fn) report_fatal_error("Type identifier may not contain both global " "variables and functions"); - *OFI++ = Fn; + *OFI++ = Globals[Offset]; } } @@ -867,31 +1173,92 @@ void LowerTypeTests::buildBitSetsFromDisjointSet( } /// Lower all type tests in this module. -bool LowerTypeTests::lower() { +LowerTypeTestsModule::LowerTypeTestsModule(Module &M, SummaryAction Action, + ModuleSummaryIndex *Summary) + : M(M), Action(Action), Summary(Summary) { + // FIXME: Use these fields. + (void)this->Action; + (void)this->Summary; + + Triple TargetTriple(M.getTargetTriple()); + LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX(); + Arch = TargetTriple.getArch(); + OS = TargetTriple.getOS(); + ObjectFormat = TargetTriple.getObjectFormat(); +} + +bool LowerTypeTestsModule::runForTesting(Module &M) { + ModuleSummaryIndex Summary; + + // Handle the command-line summary arguments. This code is for testing + // purposes only, so we handle errors directly. 
+ if (!ClReadSummary.empty()) { + ExitOnError ExitOnErr("-lowertypetests-read-summary: " + ClReadSummary + + ": "); + auto ReadSummaryFile = + ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary))); + + yaml::Input In(ReadSummaryFile->getBuffer()); + In >> Summary; + ExitOnErr(errorCodeToError(In.error())); + } + + bool Changed = LowerTypeTestsModule(M, ClSummaryAction, &Summary).lower(); + + if (!ClWriteSummary.empty()) { + ExitOnError ExitOnErr("-lowertypetests-write-summary: " + ClWriteSummary + + ": "); + std::error_code EC; + raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::F_Text); + ExitOnErr(errorCodeToError(EC)); + + yaml::Output Out(OS); + Out << Summary; + } + + return Changed; +} + +bool LowerTypeTestsModule::lower() { Function *TypeTestFunc = - M->getFunction(Intrinsic::getName(Intrinsic::type_test)); + M.getFunction(Intrinsic::getName(Intrinsic::type_test)); if (!TypeTestFunc || TypeTestFunc->use_empty()) return false; // Equivalence class set containing type identifiers and the globals that // reference them. This is used to partition the set of type identifiers in // the module into disjoint sets. - typedef EquivalenceClasses<PointerUnion<GlobalObject *, Metadata *>> + typedef EquivalenceClasses<PointerUnion<GlobalTypeMember *, Metadata *>> GlobalClassesTy; GlobalClassesTy GlobalClasses; - // Verify the type metadata and build a mapping from type identifiers to their - // last observed index in the list of globals. This will be used later to - // deterministically order the list of type identifiers. - llvm::DenseMap<Metadata *, unsigned> TypeIdIndices; + // Verify the type metadata and build a few data structures to let us + // efficiently enumerate the type identifiers associated with a global: + // a list of GlobalTypeMembers (a GlobalObject stored alongside a vector + // of associated type metadata) and a mapping from type identifiers to their + // list of GlobalTypeMembers and last observed index in the list of globals. + // The indices will be used later to deterministically order the list of type + // identifiers. + BumpPtrAllocator Alloc; + struct TIInfo { + unsigned Index; + std::vector<GlobalTypeMember *> RefGlobals; + }; + llvm::DenseMap<Metadata *, TIInfo> TypeIdInfo; unsigned I = 0; SmallVector<MDNode *, 2> Types; - for (GlobalObject &GO : M->global_objects()) { + for (GlobalObject &GO : M.global_objects()) { Types.clear(); GO.getMetadata(LLVMContext::MD_type, Types); + if (Types.empty()) + continue; + + auto *GTM = GlobalTypeMember::create(Alloc, &GO, Types); for (MDNode *Type : Types) { verifyTypeMDNode(&GO, Type); - TypeIdIndices[cast<MDNode>(Type)->getOperand(1)] = ++I; + auto &Info = TypeIdInfo[cast<MDNode>(Type)->getOperand(1)]; + Info.Index = ++I; + Info.RefGlobals.push_back(GTM); } } @@ -900,8 +1267,7 @@ bool LowerTypeTests::lower() { auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1)); if (!BitSetMDVal) - report_fatal_error( - "Second argument of llvm.type.test must be metadata"); + report_fatal_error("Second argument of llvm.type.test must be metadata"); auto BitSet = BitSetMDVal->getMetadata(); // Add the call site to the list of call sites for this type identifier. We @@ -920,14 +1286,9 @@ bool LowerTypeTests::lower() { GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI); // Add the referenced globals to the type identifier's equivalence class. 
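Editor's sketch: the TypeIdInfo map introduced above exists so that, for each llvm.type.test call, the pass can union the type identifier with exactly the globals carrying that identifier instead of rescanning every global in the module. A toy version of the bookkeeping with strings in place of Metadata and GlobalTypeMember (all names illustrative):

#include <map>
#include <string>
#include <utility>
#include <vector>

struct TIInfoSketch {
  unsigned Index = 0;                   // Last observed position; used later
                                        // for deterministic ordering.
  std::vector<std::string> RefGlobals;  // Globals tagged with this type id.
};

int main() {
  // (global, type-id) pairs as they would come from !type metadata.
  std::vector<std::pair<std::string, std::string>> TypeMD = {
      {"vtable.A", "_ZTS1A"}, {"vtable.B", "_ZTS1A"}, {"vtable.B", "_ZTS1B"}};

  std::map<std::string, TIInfoSketch> TypeIdInfo;
  unsigned I = 0;
  for (const auto &GT : TypeMD) {
    TIInfoSketch &Info = TypeIdInfo[GT.second];
    Info.Index = ++I;
    Info.RefGlobals.push_back(GT.first);
  }
  // A type test against "_ZTS1A" now unions only {vtable.A, vtable.B} into
  // its equivalence class, with no second walk over the module.
  return TypeIdInfo["_ZTS1A"].RefGlobals.size() == 2 ? 0 : 1;
}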
- for (GlobalObject &GO : M->global_objects()) { - Types.clear(); - GO.getMetadata(LLVMContext::MD_type, Types); - for (MDNode *Type : Types) - if (Type->getOperand(1) == BitSet) - CurSet = GlobalClasses.unionSets( - CurSet, GlobalClasses.findLeader(GlobalClasses.insert(&GO))); - } + for (GlobalTypeMember *GTM : TypeIdInfo[BitSet].RefGlobals) + CurSet = GlobalClasses.unionSets( + CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM))); } if (GlobalClasses.empty()) @@ -939,14 +1300,15 @@ bool LowerTypeTests::lower() { for (GlobalClassesTy::iterator I = GlobalClasses.begin(), E = GlobalClasses.end(); I != E; ++I) { - if (!I->isLeader()) continue; + if (!I->isLeader()) + continue; ++NumTypeIdDisjointSets; unsigned MaxIndex = 0; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I); MI != GlobalClasses.member_end(); ++MI) { if ((*MI).is<Metadata *>()) - MaxIndex = std::max(MaxIndex, TypeIdIndices[MI->get<Metadata *>()]); + MaxIndex = std::max(MaxIndex, TypeIdInfo[MI->get<Metadata *>()].Index); } Sets.emplace_back(I, MaxIndex); } @@ -960,20 +1322,20 @@ bool LowerTypeTests::lower() { for (const auto &S : Sets) { // Build the list of type identifiers in this disjoint set. std::vector<Metadata *> TypeIds; - std::vector<GlobalObject *> Globals; + std::vector<GlobalTypeMember *> Globals; for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(S.first); MI != GlobalClasses.member_end(); ++MI) { if ((*MI).is<Metadata *>()) TypeIds.push_back(MI->get<Metadata *>()); else - Globals.push_back(MI->get<GlobalObject *>()); + Globals.push_back(MI->get<GlobalTypeMember *>()); } // Order type identifiers by global index for determinism. This ordering is // stable as there is a one-to-one mapping between metadata and indices. std::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) { - return TypeIdIndices[M1] < TypeIdIndices[M2]; + return TypeIdInfo[M1].Index < TypeIdInfo[M2].Index; }); // Build bitsets for this disjoint set. @@ -985,35 +1347,10 @@ bool LowerTypeTests::lower() { return true; } -// Initialization helper shared by the old and the new PM. 
-static void init(LowerTypeTests *LTT, Module &M) { - LTT->M = &M; - const DataLayout &DL = M.getDataLayout(); - Triple TargetTriple(M.getTargetTriple()); - LTT->LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX(); - LTT->Arch = TargetTriple.getArch(); - LTT->ObjectFormat = TargetTriple.getObjectFormat(); - LTT->Int1Ty = Type::getInt1Ty(M.getContext()); - LTT->Int8Ty = Type::getInt8Ty(M.getContext()); - LTT->Int32Ty = Type::getInt32Ty(M.getContext()); - LTT->Int32PtrTy = PointerType::getUnqual(LTT->Int32Ty); - LTT->Int64Ty = Type::getInt64Ty(M.getContext()); - LTT->IntPtrTy = DL.getIntPtrType(M.getContext(), 0); - LTT->TypeTestCallSites.clear(); -} - -bool LowerTypeTests::runOnModule(Module &M) { - if (skipModule(M)) - return false; - init(this, M); - return lower(); -} - PreservedAnalyses LowerTypeTestsPass::run(Module &M, - AnalysisManager<Module> &AM) { - LowerTypeTests Impl; - init(&Impl, M); - bool Changed = Impl.lower(); + ModuleAnalysisManager &AM) { + bool Changed = + LowerTypeTestsModule(M, SummaryAction::None, /*Summary=*/nullptr).lower(); if (!Changed) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp index fe653a7..e0bb0eb 100644 --- a/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp +++ b/contrib/llvm/lib/Transforms/IPO/MergeFunctions.cpp @@ -97,11 +97,9 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/IR/ValueHandle.h" #include "llvm/IR/ValueMap.h" #include "llvm/Pass.h" @@ -110,6 +108,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Utils/FunctionComparator.h" #include <vector> using namespace llvm; @@ -130,328 +129,6 @@ static cl::opt<unsigned> NumFunctionsForSanityCheck( namespace { -/// GlobalNumberState assigns an integer to each global value in the program, -/// which is used by the comparison routine to order references to globals. This -/// state must be preserved throughout the pass, because Functions and other -/// globals need to maintain their relative order. Globals are assigned a number -/// when they are first visited. This order is deterministic, and so the -/// assigned numbers are as well. When two functions are merged, neither number -/// is updated. If the symbols are weak, this would be incorrect. If they are -/// strong, then one will be replaced at all references to the other, and so -/// direct callsites will now see one or the other symbol, and no update is -/// necessary. Note that if we were guaranteed unique names, we could just -/// compare those, but this would not work for stripped bitcodes or for those -/// few symbols without a name. -class GlobalNumberState { - struct Config : ValueMapConfig<GlobalValue*> { - enum { FollowRAUW = false }; - }; - // Each GlobalValue is mapped to an identifier. The Config ensures when RAUW - // occurs, the mapping does not change. Tracking changes is unnecessary, and - // also problematic for weak symbols (which may be overwritten). - typedef ValueMap<GlobalValue *, uint64_t, Config> ValueNumberMap; - ValueNumberMap GlobalNumbers; - // The next unused serial number to assign to a global. 
- uint64_t NextNumber; - public: - GlobalNumberState() : GlobalNumbers(), NextNumber(0) {} - uint64_t getNumber(GlobalValue* Global) { - ValueNumberMap::iterator MapIter; - bool Inserted; - std::tie(MapIter, Inserted) = GlobalNumbers.insert({Global, NextNumber}); - if (Inserted) - NextNumber++; - return MapIter->second; - } - void clear() { - GlobalNumbers.clear(); - } -}; - -/// FunctionComparator - Compares two functions to determine whether or not -/// they will generate machine code with the same behaviour. DataLayout is -/// used if available. The comparator always fails conservatively (erring on the -/// side of claiming that two functions are different). -class FunctionComparator { -public: - FunctionComparator(const Function *F1, const Function *F2, - GlobalNumberState* GN) - : FnL(F1), FnR(F2), GlobalNumbers(GN) {} - - /// Test whether the two functions have equivalent behaviour. - int compare(); - /// Hash a function. Equivalent functions will have the same hash, and unequal - /// functions will have different hashes with high probability. - typedef uint64_t FunctionHash; - static FunctionHash functionHash(Function &); - -private: - /// Test whether two basic blocks have equivalent behaviour. - int cmpBasicBlocks(const BasicBlock *BBL, const BasicBlock *BBR) const; - - /// Constants comparison. - /// Its analog to lexicographical comparison between hypothetical numbers - /// of next format: - /// <bitcastability-trait><raw-bit-contents> - /// - /// 1. Bitcastability. - /// Check whether L's type could be losslessly bitcasted to R's type. - /// On this stage method, in case when lossless bitcast is not possible - /// method returns -1 or 1, thus also defining which type is greater in - /// context of bitcastability. - /// Stage 0: If types are equal in terms of cmpTypes, then we can go straight - /// to the contents comparison. - /// If types differ, remember types comparison result and check - /// whether we still can bitcast types. - /// Stage 1: Types that satisfies isFirstClassType conditions are always - /// greater then others. - /// Stage 2: Vector is greater then non-vector. - /// If both types are vectors, then vector with greater bitwidth is - /// greater. - /// If both types are vectors with the same bitwidth, then types - /// are bitcastable, and we can skip other stages, and go to contents - /// comparison. - /// Stage 3: Pointer types are greater than non-pointers. If both types are - /// pointers of the same address space - go to contents comparison. - /// Different address spaces: pointer with greater address space is - /// greater. - /// Stage 4: Types are neither vectors, nor pointers. And they differ. - /// We don't know how to bitcast them. So, we better don't do it, - /// and return types comparison result (so it determines the - /// relationship among constants we don't know how to bitcast). - /// - /// Just for clearance, let's see how the set of constants could look - /// on single dimension axis: - /// - /// [NFCT], [FCT, "others"], [FCT, pointers], [FCT, vectors] - /// Where: NFCT - Not a FirstClassType - /// FCT - FirstClassTyp: - /// - /// 2. Compare raw contents. - /// It ignores types on this stage and only compares bits from L and R. - /// Returns 0, if L and R has equivalent contents. - /// -1 or 1 if values are different. - /// Pretty trivial: - /// 2.1. If contents are numbers, compare numbers. - /// Ints with greater bitwidth are greater. Ints with same bitwidths - /// compared by their contents. - /// 2.2. "And so on". 
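Editor's sketch: the numbering scheme removed here (it moves into FunctionComparator under Transforms/Utils, per the new include above) is simple to model: a global gets a serial number the first time it is seen and keeps it forever, which is what makes the comparison order deterministic even after merges. A standalone sketch, not the LLVM class:

#include <cstdint>
#include <map>

// Assign a stable, first-seen serial number to each key; the number never
// changes afterwards.
template <typename Key> class NumbererSketch {
  std::map<Key, uint64_t> Numbers;
  uint64_t Next = 0;

public:
  uint64_t getNumber(const Key &K) {
    auto It = Numbers.emplace(K, Next);
    if (It.second)  // Newly inserted: consume the next serial number.
      ++Next;
    return It.first->second;
  }
};

int main() {
  NumbererSketch<const void *> GN;
  int A, B;
  uint64_t NA = GN.getNumber(&A);  // 0
  uint64_t NB = GN.getNumber(&B);  // 1
  return (GN.getNumber(&A) == NA && NB == 1) ? 0 : 1;  // &A keeps its number.
}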
Just to avoid discrepancies with comments - /// perhaps it would be better to read the implementation itself. - /// 3. And again about overall picture. Let's look back at how the ordered set - /// of constants will look like: - /// [NFCT], [FCT, "others"], [FCT, pointers], [FCT, vectors] - /// - /// Now look, what could be inside [FCT, "others"], for example: - /// [FCT, "others"] = - /// [ - /// [double 0.1], [double 1.23], - /// [i32 1], [i32 2], - /// { double 1.0 }, ; StructTyID, NumElements = 1 - /// { i32 1 }, ; StructTyID, NumElements = 1 - /// { double 1, i32 1 }, ; StructTyID, NumElements = 2 - /// { i32 1, double 1 } ; StructTyID, NumElements = 2 - /// ] - /// - /// Let's explain the order. Float numbers will be less than integers, just - /// because of cmpType terms: FloatTyID < IntegerTyID. - /// Floats (with same fltSemantics) are sorted according to their value. - /// Then you can see integers, and they are, like a floats, - /// could be easy sorted among each others. - /// The structures. Structures are grouped at the tail, again because of their - /// TypeID: StructTyID > IntegerTyID > FloatTyID. - /// Structures with greater number of elements are greater. Structures with - /// greater elements going first are greater. - /// The same logic with vectors, arrays and other possible complex types. - /// - /// Bitcastable constants. - /// Let's assume, that some constant, belongs to some group of - /// "so-called-equal" values with different types, and at the same time - /// belongs to another group of constants with equal types - /// and "really" equal values. - /// - /// Now, prove that this is impossible: - /// - /// If constant A with type TyA is bitcastable to B with type TyB, then: - /// 1. All constants with equal types to TyA, are bitcastable to B. Since - /// those should be vectors (if TyA is vector), pointers - /// (if TyA is pointer), or else (if TyA equal to TyB), those types should - /// be equal to TyB. - /// 2. All constants with non-equal, but bitcastable types to TyA, are - /// bitcastable to B. - /// Once again, just because we allow it to vectors and pointers only. - /// This statement could be expanded as below: - /// 2.1. All vectors with equal bitwidth to vector A, has equal bitwidth to - /// vector B, and thus bitcastable to B as well. - /// 2.2. All pointers of the same address space, no matter what they point to, - /// bitcastable. So if C is pointer, it could be bitcasted to A and to B. - /// So any constant equal or bitcastable to A is equal or bitcastable to B. - /// QED. - /// - /// In another words, for pointers and vectors, we ignore top-level type and - /// look at their particular properties (bit-width for vectors, and - /// address space for pointers). - /// If these properties are equal - compare their contents. - int cmpConstants(const Constant *L, const Constant *R) const; - - /// Compares two global values by number. Uses the GlobalNumbersState to - /// identify the same gobals across function calls. - int cmpGlobalValues(GlobalValue *L, GlobalValue *R) const; - - /// Assign or look up previously assigned numbers for the two values, and - /// return whether the numbers are equal. Numbers are assigned in the order - /// visited. - /// Comparison order: - /// Stage 0: Value that is function itself is always greater then others. - /// If left and right values are references to their functions, then - /// they are equal. - /// Stage 1: Constants are greater than non-constants. 
- /// If both left and right are constants, then the result of - /// cmpConstants is used as cmpValues result. - /// Stage 2: InlineAsm instances are greater than others. If both left and - /// right are InlineAsm instances, InlineAsm* pointers casted to - /// integers and compared as numbers. - /// Stage 3: For all other cases we compare order we meet these values in - /// their functions. If right value was met first during scanning, - /// then left value is greater. - /// In another words, we compare serial numbers, for more details - /// see comments for sn_mapL and sn_mapR. - int cmpValues(const Value *L, const Value *R) const; - - /// Compare two Instructions for equivalence, similar to - /// Instruction::isSameOperationAs. - /// - /// Stages are listed in "most significant stage first" order: - /// On each stage below, we do comparison between some left and right - /// operation parts. If parts are non-equal, we assign parts comparison - /// result to the operation comparison result and exit from method. - /// Otherwise we proceed to the next stage. - /// Stages: - /// 1. Operations opcodes. Compared as numbers. - /// 2. Number of operands. - /// 3. Operation types. Compared with cmpType method. - /// 4. Compare operation subclass optional data as stream of bytes: - /// just convert it to integers and call cmpNumbers. - /// 5. Compare in operation operand types with cmpType in - /// most significant operand first order. - /// 6. Last stage. Check operations for some specific attributes. - /// For example, for Load it would be: - /// 6.1.Load: volatile (as boolean flag) - /// 6.2.Load: alignment (as integer numbers) - /// 6.3.Load: ordering (as underlying enum class value) - /// 6.4.Load: synch-scope (as integer numbers) - /// 6.5.Load: range metadata (as integer ranges) - /// On this stage its better to see the code, since its not more than 10-15 - /// strings for particular instruction, and could change sometimes. - int cmpOperations(const Instruction *L, const Instruction *R) const; - - /// Compare two GEPs for equivalent pointer arithmetic. - /// Parts to be compared for each comparison stage, - /// most significant stage first: - /// 1. Address space. As numbers. - /// 2. Constant offset, (using GEPOperator::accumulateConstantOffset method). - /// 3. Pointer operand type (using cmpType method). - /// 4. Number of operands. - /// 5. Compare operands, using cmpValues method. - int cmpGEPs(const GEPOperator *GEPL, const GEPOperator *GEPR) const; - int cmpGEPs(const GetElementPtrInst *GEPL, - const GetElementPtrInst *GEPR) const { - return cmpGEPs(cast<GEPOperator>(GEPL), cast<GEPOperator>(GEPR)); - } - - /// cmpType - compares two types, - /// defines total ordering among the types set. - /// - /// Return values: - /// 0 if types are equal, - /// -1 if Left is less than Right, - /// +1 if Left is greater than Right. - /// - /// Description: - /// Comparison is broken onto stages. Like in lexicographical comparison - /// stage coming first has higher priority. - /// On each explanation stage keep in mind total ordering properties. - /// - /// 0. Before comparison we coerce pointer types of 0 address space to - /// integer. - /// We also don't bother with same type at left and right, so - /// just return 0 in this case. - /// - /// 1. If types are of different kind (different type IDs). - /// Return result of type IDs comparison, treating them as numbers. - /// 2. If types are integers, check that they have the same width. 
If they - /// are vectors, check that they have the same count and subtype. - /// 3. Types have the same ID, so check whether they are one of: - /// * Void - /// * Float - /// * Double - /// * X86_FP80 - /// * FP128 - /// * PPC_FP128 - /// * Label - /// * Metadata - /// We can treat these types as equal whenever their IDs are same. - /// 4. If Left and Right are pointers, return result of address space - /// comparison (numbers comparison). We can treat pointer types of same - /// address space as equal. - /// 5. If types are complex. - /// Then both Left and Right are to be expanded and their element types will - /// be checked with the same way. If we get Res != 0 on some stage, return it. - /// Otherwise return 0. - /// 6. For all other cases put llvm_unreachable. - int cmpTypes(Type *TyL, Type *TyR) const; - - int cmpNumbers(uint64_t L, uint64_t R) const; - int cmpOrderings(AtomicOrdering L, AtomicOrdering R) const; - int cmpAPInts(const APInt &L, const APInt &R) const; - int cmpAPFloats(const APFloat &L, const APFloat &R) const; - int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const; - int cmpMem(StringRef L, StringRef R) const; - int cmpAttrs(const AttributeSet L, const AttributeSet R) const; - int cmpRangeMetadata(const MDNode *L, const MDNode *R) const; - int cmpOperandBundlesSchema(const Instruction *L, const Instruction *R) const; - - // The two functions undergoing comparison. - const Function *FnL, *FnR; - - /// Assign serial numbers to values from left function, and values from - /// right function. - /// Explanation: - /// Being comparing functions we need to compare values we meet at left and - /// right sides. - /// Its easy to sort things out for external values. It just should be - /// the same value at left and right. - /// But for local values (those were introduced inside function body) - /// we have to ensure they were introduced at exactly the same place, - /// and plays the same role. - /// Let's assign serial number to each value when we meet it first time. - /// Values that were met at same place will be with same serial numbers. - /// In this case it would be good to explain few points about values assigned - /// to BBs and other ways of implementation (see below). - /// - /// 1. Safety of BB reordering. - /// It's safe to change the order of BasicBlocks in function. - /// Relationship with other functions and serial numbering will not be - /// changed in this case. - /// As follows from FunctionComparator::compare(), we do CFG walk: we start - /// from the entry, and then take each terminator. So it doesn't matter how in - /// fact BBs are ordered in function. And since cmpValues are called during - /// this walk, the numbering depends only on how BBs located inside the CFG. - /// So the answer is - yes. We will get the same numbering. - /// - /// 2. Impossibility to use dominance properties of values. - /// If we compare two instruction operands: first is usage of local - /// variable AL from function FL, and second is usage of local variable AR - /// from FR, we could compare their origins and check whether they are - /// defined at the same place. - /// But, we are still not able to compare operands of PHI nodes, since those - /// could be operands from further BBs we didn't scan yet. - /// So it's impossible to use dominance properties in general. 
- mutable DenseMap<const Value*, int> sn_mapL, sn_mapR; - - // The global state we will use - GlobalNumberState* GlobalNumbers; -}; - class FunctionNode { mutable AssertingVH<Function> F; FunctionComparator::FunctionHash Hash; @@ -470,898 +147,6 @@ public: void release() { F = nullptr; } }; -} // end anonymous namespace - -int FunctionComparator::cmpNumbers(uint64_t L, uint64_t R) const { - if (L < R) return -1; - if (L > R) return 1; - return 0; -} - -int FunctionComparator::cmpOrderings(AtomicOrdering L, AtomicOrdering R) const { - if ((int)L < (int)R) return -1; - if ((int)L > (int)R) return 1; - return 0; -} - -int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const { - if (int Res = cmpNumbers(L.getBitWidth(), R.getBitWidth())) - return Res; - if (L.ugt(R)) return 1; - if (R.ugt(L)) return -1; - return 0; -} - -int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const { - // Floats are ordered first by semantics (i.e. float, double, half, etc.), - // then by value interpreted as a bitstring (aka APInt). - const fltSemantics &SL = L.getSemantics(), &SR = R.getSemantics(); - if (int Res = cmpNumbers(APFloat::semanticsPrecision(SL), - APFloat::semanticsPrecision(SR))) - return Res; - if (int Res = cmpNumbers(APFloat::semanticsMaxExponent(SL), - APFloat::semanticsMaxExponent(SR))) - return Res; - if (int Res = cmpNumbers(APFloat::semanticsMinExponent(SL), - APFloat::semanticsMinExponent(SR))) - return Res; - if (int Res = cmpNumbers(APFloat::semanticsSizeInBits(SL), - APFloat::semanticsSizeInBits(SR))) - return Res; - return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt()); -} - -int FunctionComparator::cmpMem(StringRef L, StringRef R) const { - // Prevent heavy comparison, compare sizes first. - if (int Res = cmpNumbers(L.size(), R.size())) - return Res; - - // Compare strings lexicographically only when it is necessary: only when - // strings are equal in size. - return L.compare(R); -} - -int FunctionComparator::cmpAttrs(const AttributeSet L, - const AttributeSet R) const { - if (int Res = cmpNumbers(L.getNumSlots(), R.getNumSlots())) - return Res; - - for (unsigned i = 0, e = L.getNumSlots(); i != e; ++i) { - AttributeSet::iterator LI = L.begin(i), LE = L.end(i), RI = R.begin(i), - RE = R.end(i); - for (; LI != LE && RI != RE; ++LI, ++RI) { - Attribute LA = *LI; - Attribute RA = *RI; - if (LA < RA) - return -1; - if (RA < LA) - return 1; - } - if (LI != LE) - return 1; - if (RI != RE) - return -1; - } - return 0; -} - -int FunctionComparator::cmpRangeMetadata(const MDNode *L, - const MDNode *R) const { - if (L == R) - return 0; - if (!L) - return -1; - if (!R) - return 1; - // Range metadata is a sequence of numbers. Make sure they are the same - // sequence. - // TODO: Note that as this is metadata, it is possible to drop and/or merge - // this data when considering functions to merge. Thus this comparison would - // return 0 (i.e. equivalent), but merging would become more complicated - // because the ranges would need to be unioned. It is not likely that - // functions differ ONLY in this metadata if they are actually the same - // function semantically. 
- if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands())) - return Res; - for (size_t I = 0; I < L->getNumOperands(); ++I) { - ConstantInt *LLow = mdconst::extract<ConstantInt>(L->getOperand(I)); - ConstantInt *RLow = mdconst::extract<ConstantInt>(R->getOperand(I)); - if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue())) - return Res; - } - return 0; -} - -int FunctionComparator::cmpOperandBundlesSchema(const Instruction *L, - const Instruction *R) const { - ImmutableCallSite LCS(L); - ImmutableCallSite RCS(R); - - assert(LCS && RCS && "Must be calls or invokes!"); - assert(LCS.isCall() == RCS.isCall() && "Can't compare otherwise!"); - - if (int Res = - cmpNumbers(LCS.getNumOperandBundles(), RCS.getNumOperandBundles())) - return Res; - - for (unsigned i = 0, e = LCS.getNumOperandBundles(); i != e; ++i) { - auto OBL = LCS.getOperandBundleAt(i); - auto OBR = RCS.getOperandBundleAt(i); - - if (int Res = OBL.getTagName().compare(OBR.getTagName())) - return Res; - - if (int Res = cmpNumbers(OBL.Inputs.size(), OBR.Inputs.size())) - return Res; - } - - return 0; -} - -/// Constants comparison: -/// 1. Check whether type of L constant could be losslessly bitcasted to R -/// type. -/// 2. Compare constant contents. -/// For more details see declaration comments. -int FunctionComparator::cmpConstants(const Constant *L, - const Constant *R) const { - - Type *TyL = L->getType(); - Type *TyR = R->getType(); - - // Check whether types are bitcastable. This part is just re-factored - // Type::canLosslesslyBitCastTo method, but instead of returning true/false, - // we also pack into result which type is "less" for us. - int TypesRes = cmpTypes(TyL, TyR); - if (TypesRes != 0) { - // Types are different, but check whether we can bitcast them. - if (!TyL->isFirstClassType()) { - if (TyR->isFirstClassType()) - return -1; - // Neither TyL nor TyR are values of first class type. Return the result - // of comparing the types - return TypesRes; - } - if (!TyR->isFirstClassType()) { - if (TyL->isFirstClassType()) - return 1; - return TypesRes; - } - - // Vector -> Vector conversions are always lossless if the two vector types - // have the same size, otherwise not. - unsigned TyLWidth = 0; - unsigned TyRWidth = 0; - - if (auto *VecTyL = dyn_cast<VectorType>(TyL)) - TyLWidth = VecTyL->getBitWidth(); - if (auto *VecTyR = dyn_cast<VectorType>(TyR)) - TyRWidth = VecTyR->getBitWidth(); - - if (TyLWidth != TyRWidth) - return cmpNumbers(TyLWidth, TyRWidth); - - // Zero bit-width means neither TyL nor TyR are vectors. - if (!TyLWidth) { - PointerType *PTyL = dyn_cast<PointerType>(TyL); - PointerType *PTyR = dyn_cast<PointerType>(TyR); - if (PTyL && PTyR) { - unsigned AddrSpaceL = PTyL->getAddressSpace(); - unsigned AddrSpaceR = PTyR->getAddressSpace(); - if (int Res = cmpNumbers(AddrSpaceL, AddrSpaceR)) - return Res; - } - if (PTyL) - return 1; - if (PTyR) - return -1; - - // TyL and TyR aren't vectors, nor pointers. We don't know how to - // bitcast them. - return TypesRes; - } - } - - // OK, types are bitcastable, now check constant contents. 
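Editor's sketch: the bitcastability test above boils down to two special cases, namely that vectors with equal total bit width are interchangeable and pointers in the same address space are interchangeable, with everything else ordered by the type comparison. A reduced model of that decision over toy type descriptors (not LLVM types):

#include <cstdint>

// Minimal type descriptor: a vector with a bit width, a pointer with an
// address space, or something else.
struct TySketch {
  enum { Vector, Pointer, Other } Kind;
  unsigned BitsOrAddrSpace;
};

// True when a constant of type L could be losslessly bitcast to R,
// mirroring the two special cases in the hunk above.
bool bitcastable(const TySketch &L, const TySketch &R) {
  if (L.Kind == TySketch::Vector && R.Kind == TySketch::Vector)
    return L.BitsOrAddrSpace == R.BitsOrAddrSpace;  // Same total bit width.
  if (L.Kind == TySketch::Pointer && R.Kind == TySketch::Pointer)
    return L.BitsOrAddrSpace == R.BitsOrAddrSpace;  // Same address space.
  return false;  // Otherwise the type ordering decides the result.
}

int main() {
  TySketch V128a{TySketch::Vector, 128}, V128b{TySketch::Vector, 128};
  TySketch P0{TySketch::Pointer, 0}, P1{TySketch::Pointer, 1};
  return (bitcastable(V128a, V128b) && !bitcastable(P0, P1)) ? 0 : 1;
}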
- - if (L->isNullValue() && R->isNullValue()) - return TypesRes; - if (L->isNullValue() && !R->isNullValue()) - return 1; - if (!L->isNullValue() && R->isNullValue()) - return -1; - - auto GlobalValueL = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(L)); - auto GlobalValueR = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(R)); - if (GlobalValueL && GlobalValueR) { - return cmpGlobalValues(GlobalValueL, GlobalValueR); - } - - if (int Res = cmpNumbers(L->getValueID(), R->getValueID())) - return Res; - - if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) { - const auto *SeqR = cast<ConstantDataSequential>(R); - // This handles ConstantDataArray and ConstantDataVector. Note that we - // compare the two raw data arrays, which might differ depending on the host - // endianness. This isn't a problem though, because the endiness of a module - // will affect the order of the constants, but this order is the same - // for a given input module and host platform. - return cmpMem(SeqL->getRawDataValues(), SeqR->getRawDataValues()); - } - - switch (L->getValueID()) { - case Value::UndefValueVal: - case Value::ConstantTokenNoneVal: - return TypesRes; - case Value::ConstantIntVal: { - const APInt &LInt = cast<ConstantInt>(L)->getValue(); - const APInt &RInt = cast<ConstantInt>(R)->getValue(); - return cmpAPInts(LInt, RInt); - } - case Value::ConstantFPVal: { - const APFloat &LAPF = cast<ConstantFP>(L)->getValueAPF(); - const APFloat &RAPF = cast<ConstantFP>(R)->getValueAPF(); - return cmpAPFloats(LAPF, RAPF); - } - case Value::ConstantArrayVal: { - const ConstantArray *LA = cast<ConstantArray>(L); - const ConstantArray *RA = cast<ConstantArray>(R); - uint64_t NumElementsL = cast<ArrayType>(TyL)->getNumElements(); - uint64_t NumElementsR = cast<ArrayType>(TyR)->getNumElements(); - if (int Res = cmpNumbers(NumElementsL, NumElementsR)) - return Res; - for (uint64_t i = 0; i < NumElementsL; ++i) { - if (int Res = cmpConstants(cast<Constant>(LA->getOperand(i)), - cast<Constant>(RA->getOperand(i)))) - return Res; - } - return 0; - } - case Value::ConstantStructVal: { - const ConstantStruct *LS = cast<ConstantStruct>(L); - const ConstantStruct *RS = cast<ConstantStruct>(R); - unsigned NumElementsL = cast<StructType>(TyL)->getNumElements(); - unsigned NumElementsR = cast<StructType>(TyR)->getNumElements(); - if (int Res = cmpNumbers(NumElementsL, NumElementsR)) - return Res; - for (unsigned i = 0; i != NumElementsL; ++i) { - if (int Res = cmpConstants(cast<Constant>(LS->getOperand(i)), - cast<Constant>(RS->getOperand(i)))) - return Res; - } - return 0; - } - case Value::ConstantVectorVal: { - const ConstantVector *LV = cast<ConstantVector>(L); - const ConstantVector *RV = cast<ConstantVector>(R); - unsigned NumElementsL = cast<VectorType>(TyL)->getNumElements(); - unsigned NumElementsR = cast<VectorType>(TyR)->getNumElements(); - if (int Res = cmpNumbers(NumElementsL, NumElementsR)) - return Res; - for (uint64_t i = 0; i < NumElementsL; ++i) { - if (int Res = cmpConstants(cast<Constant>(LV->getOperand(i)), - cast<Constant>(RV->getOperand(i)))) - return Res; - } - return 0; - } - case Value::ConstantExprVal: { - const ConstantExpr *LE = cast<ConstantExpr>(L); - const ConstantExpr *RE = cast<ConstantExpr>(R); - unsigned NumOperandsL = LE->getNumOperands(); - unsigned NumOperandsR = RE->getNumOperands(); - if (int Res = cmpNumbers(NumOperandsL, NumOperandsR)) - return Res; - for (unsigned i = 0; i < NumOperandsL; ++i) { - if (int Res = cmpConstants(cast<Constant>(LE->getOperand(i)), - 
cast<Constant>(RE->getOperand(i)))) - return Res; - } - return 0; - } - case Value::BlockAddressVal: { - const BlockAddress *LBA = cast<BlockAddress>(L); - const BlockAddress *RBA = cast<BlockAddress>(R); - if (int Res = cmpValues(LBA->getFunction(), RBA->getFunction())) - return Res; - if (LBA->getFunction() == RBA->getFunction()) { - // They are BBs in the same function. Order by which comes first in the - // BB order of the function. This order is deterministic. - Function* F = LBA->getFunction(); - BasicBlock *LBB = LBA->getBasicBlock(); - BasicBlock *RBB = RBA->getBasicBlock(); - if (LBB == RBB) - return 0; - for(BasicBlock &BB : F->getBasicBlockList()) { - if (&BB == LBB) { - assert(&BB != RBB); - return -1; - } - if (&BB == RBB) - return 1; - } - llvm_unreachable("Basic Block Address does not point to a basic block in " - "its function."); - return -1; - } else { - // cmpValues said the functions are the same. So because they aren't - // literally the same pointer, they must respectively be the left and - // right functions. - assert(LBA->getFunction() == FnL && RBA->getFunction() == FnR); - // cmpValues will tell us if these are equivalent BasicBlocks, in the - // context of their respective functions. - return cmpValues(LBA->getBasicBlock(), RBA->getBasicBlock()); - } - } - default: // Unknown constant, abort. - DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n"); - llvm_unreachable("Constant ValueID not recognized."); - return -1; - } -} - -int FunctionComparator::cmpGlobalValues(GlobalValue *L, GlobalValue *R) const { - return cmpNumbers(GlobalNumbers->getNumber(L), GlobalNumbers->getNumber(R)); -} - -/// cmpType - compares two types, -/// defines total ordering among the types set. -/// See method declaration comments for more details. -int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { - PointerType *PTyL = dyn_cast<PointerType>(TyL); - PointerType *PTyR = dyn_cast<PointerType>(TyR); - - const DataLayout &DL = FnL->getParent()->getDataLayout(); - if (PTyL && PTyL->getAddressSpace() == 0) - TyL = DL.getIntPtrType(TyL); - if (PTyR && PTyR->getAddressSpace() == 0) - TyR = DL.getIntPtrType(TyR); - - if (TyL == TyR) - return 0; - - if (int Res = cmpNumbers(TyL->getTypeID(), TyR->getTypeID())) - return Res; - - switch (TyL->getTypeID()) { - default: - llvm_unreachable("Unknown type!"); - // Fall through in Release mode. - case Type::IntegerTyID: - return cmpNumbers(cast<IntegerType>(TyL)->getBitWidth(), - cast<IntegerType>(TyR)->getBitWidth()); - case Type::VectorTyID: { - VectorType *VTyL = cast<VectorType>(TyL), *VTyR = cast<VectorType>(TyR); - if (int Res = cmpNumbers(VTyL->getNumElements(), VTyR->getNumElements())) - return Res; - return cmpTypes(VTyL->getElementType(), VTyR->getElementType()); - } - // TyL == TyR would have returned true earlier, because types are uniqued. 
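Editor's sketch: every comparator in this file follows the same idiom: compare the most significant property first and return the first nonzero result, so the overall relation stays a total order. A distilled example of that pattern over a toy struct-like type, in the spirit of the StructTyID case above but not the LLVM implementation:

#include <cstdint>
#include <vector>

// Three-way comparison helper, same shape as cmpNumbers above.
static int cmp3(uint64_t L, uint64_t R) {
  if (L < R) return -1;
  if (L > R) return 1;
  return 0;
}

// Toy "type": a type ID plus element IDs, compared stage by stage.
struct TyS {
  unsigned ID;
  std::vector<unsigned> Elems;
};

int cmpTy(const TyS &L, const TyS &R) {
  if (int Res = cmp3(L.ID, R.ID))
    return Res;  // Different kinds of type.
  if (int Res = cmp3(L.Elems.size(), R.Elems.size()))
    return Res;  // Fewer elements orders first.
  for (size_t I = 0; I != L.Elems.size(); ++I)
    if (int Res = cmp3(L.Elems[I], R.Elems[I]))
      return Res;  // First differing element decides.
  return 0;
}

int main() {
  TyS A{13, {1, 2}}, B{13, {1, 3}};
  return cmpTy(A, B) < 0 ? 0 : 1;  // A orders before B.
}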
- case Type::VoidTyID: - case Type::FloatTyID: - case Type::DoubleTyID: - case Type::X86_FP80TyID: - case Type::FP128TyID: - case Type::PPC_FP128TyID: - case Type::LabelTyID: - case Type::MetadataTyID: - case Type::TokenTyID: - return 0; - - case Type::PointerTyID: { - assert(PTyL && PTyR && "Both types must be pointers here."); - return cmpNumbers(PTyL->getAddressSpace(), PTyR->getAddressSpace()); - } - - case Type::StructTyID: { - StructType *STyL = cast<StructType>(TyL); - StructType *STyR = cast<StructType>(TyR); - if (STyL->getNumElements() != STyR->getNumElements()) - return cmpNumbers(STyL->getNumElements(), STyR->getNumElements()); - - if (STyL->isPacked() != STyR->isPacked()) - return cmpNumbers(STyL->isPacked(), STyR->isPacked()); - - for (unsigned i = 0, e = STyL->getNumElements(); i != e; ++i) { - if (int Res = cmpTypes(STyL->getElementType(i), STyR->getElementType(i))) - return Res; - } - return 0; - } - - case Type::FunctionTyID: { - FunctionType *FTyL = cast<FunctionType>(TyL); - FunctionType *FTyR = cast<FunctionType>(TyR); - if (FTyL->getNumParams() != FTyR->getNumParams()) - return cmpNumbers(FTyL->getNumParams(), FTyR->getNumParams()); - - if (FTyL->isVarArg() != FTyR->isVarArg()) - return cmpNumbers(FTyL->isVarArg(), FTyR->isVarArg()); - - if (int Res = cmpTypes(FTyL->getReturnType(), FTyR->getReturnType())) - return Res; - - for (unsigned i = 0, e = FTyL->getNumParams(); i != e; ++i) { - if (int Res = cmpTypes(FTyL->getParamType(i), FTyR->getParamType(i))) - return Res; - } - return 0; - } - - case Type::ArrayTyID: { - ArrayType *ATyL = cast<ArrayType>(TyL); - ArrayType *ATyR = cast<ArrayType>(TyR); - if (ATyL->getNumElements() != ATyR->getNumElements()) - return cmpNumbers(ATyL->getNumElements(), ATyR->getNumElements()); - return cmpTypes(ATyL->getElementType(), ATyR->getElementType()); - } - } -} - -// Determine whether the two operations are the same except that pointer-to-A -// and pointer-to-B are equivalent. This should be kept in sync with -// Instruction::isSameOperationAs. -// Read method declaration comments for more details. -int FunctionComparator::cmpOperations(const Instruction *L, - const Instruction *R) const { - // Differences from Instruction::isSameOperationAs: - // * replace type comparison with calls to cmpTypes. - // * we test for I->getRawSubclassOptionalData (nuw/nsw/tail) at the top. - // * because of the above, we don't test for the tail bit on calls later on. - if (int Res = cmpNumbers(L->getOpcode(), R->getOpcode())) - return Res; - - if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands())) - return Res; - - if (int Res = cmpTypes(L->getType(), R->getType())) - return Res; - - if (int Res = cmpNumbers(L->getRawSubclassOptionalData(), - R->getRawSubclassOptionalData())) - return Res; - - // We have two instructions of identical opcode and #operands. Check to see - // if all operands are the same type - for (unsigned i = 0, e = L->getNumOperands(); i != e; ++i) { - if (int Res = - cmpTypes(L->getOperand(i)->getType(), R->getOperand(i)->getType())) - return Res; - } - - // Check special state that is a part of some instructions. 
- if (const AllocaInst *AI = dyn_cast<AllocaInst>(L)) { - if (int Res = cmpTypes(AI->getAllocatedType(), - cast<AllocaInst>(R)->getAllocatedType())) - return Res; - return cmpNumbers(AI->getAlignment(), cast<AllocaInst>(R)->getAlignment()); - } - if (const LoadInst *LI = dyn_cast<LoadInst>(L)) { - if (int Res = cmpNumbers(LI->isVolatile(), cast<LoadInst>(R)->isVolatile())) - return Res; - if (int Res = - cmpNumbers(LI->getAlignment(), cast<LoadInst>(R)->getAlignment())) - return Res; - if (int Res = - cmpOrderings(LI->getOrdering(), cast<LoadInst>(R)->getOrdering())) - return Res; - if (int Res = - cmpNumbers(LI->getSynchScope(), cast<LoadInst>(R)->getSynchScope())) - return Res; - return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range), - cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range)); - } - if (const StoreInst *SI = dyn_cast<StoreInst>(L)) { - if (int Res = - cmpNumbers(SI->isVolatile(), cast<StoreInst>(R)->isVolatile())) - return Res; - if (int Res = - cmpNumbers(SI->getAlignment(), cast<StoreInst>(R)->getAlignment())) - return Res; - if (int Res = - cmpOrderings(SI->getOrdering(), cast<StoreInst>(R)->getOrdering())) - return Res; - return cmpNumbers(SI->getSynchScope(), cast<StoreInst>(R)->getSynchScope()); - } - if (const CmpInst *CI = dyn_cast<CmpInst>(L)) - return cmpNumbers(CI->getPredicate(), cast<CmpInst>(R)->getPredicate()); - if (const CallInst *CI = dyn_cast<CallInst>(L)) { - if (int Res = cmpNumbers(CI->getCallingConv(), - cast<CallInst>(R)->getCallingConv())) - return Res; - if (int Res = - cmpAttrs(CI->getAttributes(), cast<CallInst>(R)->getAttributes())) - return Res; - if (int Res = cmpOperandBundlesSchema(CI, R)) - return Res; - return cmpRangeMetadata( - CI->getMetadata(LLVMContext::MD_range), - cast<CallInst>(R)->getMetadata(LLVMContext::MD_range)); - } - if (const InvokeInst *II = dyn_cast<InvokeInst>(L)) { - if (int Res = cmpNumbers(II->getCallingConv(), - cast<InvokeInst>(R)->getCallingConv())) - return Res; - if (int Res = - cmpAttrs(II->getAttributes(), cast<InvokeInst>(R)->getAttributes())) - return Res; - if (int Res = cmpOperandBundlesSchema(II, R)) - return Res; - return cmpRangeMetadata( - II->getMetadata(LLVMContext::MD_range), - cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range)); - } - if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) { - ArrayRef<unsigned> LIndices = IVI->getIndices(); - ArrayRef<unsigned> RIndices = cast<InsertValueInst>(R)->getIndices(); - if (int Res = cmpNumbers(LIndices.size(), RIndices.size())) - return Res; - for (size_t i = 0, e = LIndices.size(); i != e; ++i) { - if (int Res = cmpNumbers(LIndices[i], RIndices[i])) - return Res; - } - return 0; - } - if (const ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(L)) { - ArrayRef<unsigned> LIndices = EVI->getIndices(); - ArrayRef<unsigned> RIndices = cast<ExtractValueInst>(R)->getIndices(); - if (int Res = cmpNumbers(LIndices.size(), RIndices.size())) - return Res; - for (size_t i = 0, e = LIndices.size(); i != e; ++i) { - if (int Res = cmpNumbers(LIndices[i], RIndices[i])) - return Res; - } - } - if (const FenceInst *FI = dyn_cast<FenceInst>(L)) { - if (int Res = - cmpOrderings(FI->getOrdering(), cast<FenceInst>(R)->getOrdering())) - return Res; - return cmpNumbers(FI->getSynchScope(), cast<FenceInst>(R)->getSynchScope()); - } - if (const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(L)) { - if (int Res = cmpNumbers(CXI->isVolatile(), - cast<AtomicCmpXchgInst>(R)->isVolatile())) - return Res; - if (int Res = cmpNumbers(CXI->isWeak(), - 
cast<AtomicCmpXchgInst>(R)->isWeak())) - return Res; - if (int Res = - cmpOrderings(CXI->getSuccessOrdering(), - cast<AtomicCmpXchgInst>(R)->getSuccessOrdering())) - return Res; - if (int Res = - cmpOrderings(CXI->getFailureOrdering(), - cast<AtomicCmpXchgInst>(R)->getFailureOrdering())) - return Res; - return cmpNumbers(CXI->getSynchScope(), - cast<AtomicCmpXchgInst>(R)->getSynchScope()); - } - if (const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(L)) { - if (int Res = cmpNumbers(RMWI->getOperation(), - cast<AtomicRMWInst>(R)->getOperation())) - return Res; - if (int Res = cmpNumbers(RMWI->isVolatile(), - cast<AtomicRMWInst>(R)->isVolatile())) - return Res; - if (int Res = cmpOrderings(RMWI->getOrdering(), - cast<AtomicRMWInst>(R)->getOrdering())) - return Res; - return cmpNumbers(RMWI->getSynchScope(), - cast<AtomicRMWInst>(R)->getSynchScope()); - } - if (const PHINode *PNL = dyn_cast<PHINode>(L)) { - const PHINode *PNR = cast<PHINode>(R); - // Ensure that in addition to the incoming values being identical - // (checked by the caller of this function), the incoming blocks - // are also identical. - for (unsigned i = 0, e = PNL->getNumIncomingValues(); i != e; ++i) { - if (int Res = - cmpValues(PNL->getIncomingBlock(i), PNR->getIncomingBlock(i))) - return Res; - } - } - return 0; -} - -// Determine whether two GEP operations perform the same underlying arithmetic. -// Read method declaration comments for more details. -int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, - const GEPOperator *GEPR) const { - - unsigned int ASL = GEPL->getPointerAddressSpace(); - unsigned int ASR = GEPR->getPointerAddressSpace(); - - if (int Res = cmpNumbers(ASL, ASR)) - return Res; - - // When we have target data, we can reduce the GEP down to the value in bytes - // added to the address. - const DataLayout &DL = FnL->getParent()->getDataLayout(); - unsigned BitWidth = DL.getPointerSizeInBits(ASL); - APInt OffsetL(BitWidth, 0), OffsetR(BitWidth, 0); - if (GEPL->accumulateConstantOffset(DL, OffsetL) && - GEPR->accumulateConstantOffset(DL, OffsetR)) - return cmpAPInts(OffsetL, OffsetR); - if (int Res = cmpTypes(GEPL->getSourceElementType(), - GEPR->getSourceElementType())) - return Res; - - if (int Res = cmpNumbers(GEPL->getNumOperands(), GEPR->getNumOperands())) - return Res; - - for (unsigned i = 0, e = GEPL->getNumOperands(); i != e; ++i) { - if (int Res = cmpValues(GEPL->getOperand(i), GEPR->getOperand(i))) - return Res; - } - - return 0; -} - -int FunctionComparator::cmpInlineAsm(const InlineAsm *L, - const InlineAsm *R) const { - // InlineAsm's are uniqued. If they are the same pointer, obviously they are - // the same, otherwise compare the fields. - if (L == R) - return 0; - if (int Res = cmpTypes(L->getFunctionType(), R->getFunctionType())) - return Res; - if (int Res = cmpMem(L->getAsmString(), R->getAsmString())) - return Res; - if (int Res = cmpMem(L->getConstraintString(), R->getConstraintString())) - return Res; - if (int Res = cmpNumbers(L->hasSideEffects(), R->hasSideEffects())) - return Res; - if (int Res = cmpNumbers(L->isAlignStack(), R->isAlignStack())) - return Res; - if (int Res = cmpNumbers(L->getDialect(), R->getDialect())) - return Res; - llvm_unreachable("InlineAsm blocks were not uniqued."); - return 0; -} - -/// Compare two values used by the two functions under pair-wise comparison. If -/// this is the first time the values are seen, they're added to the mapping so -/// that we will detect mismatches on next use. -/// See comments in declaration for more details. 
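// ---- Sketch: the serial-number scheme behind cmpValues ------------------
// cmpValues (below) gives every non-constant value a serial number the
// first time it is seen on each side (sn_mapL / sn_mapR) and then compares
// values by those numbers.  The snippet here is a simplified standalone
// model of that scheme, with std::string standing in for llvm::Value*; it
// is illustrative only, not the real implementation.

#include <cstddef>
#include <map>
#include <string>

struct SerialNumberer {
  std::map<std::string, std::size_t> SN; // value -> index of first use

  // Returns the existing index if V was seen before; otherwise assigns the
  // next free index (insert() is a no-op for an existing key).
  std::size_t number(const std::string &V) {
    return SN.insert({V, SN.size()}).first->second;
  }
};

// 0 when L and R were first seen at the same position on their respective
// sides; any later mismatch in usage order shows up as a non-zero result.
int cmpValueSlots(SerialNumberer &Left, SerialNumberer &Right,
                  const std::string &L, const std::string &R) {
  std::size_t LN = Left.number(L), RN = Right.number(R);
  return LN < RN ? -1 : (LN > RN ? 1 : 0);
}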
-int FunctionComparator::cmpValues(const Value *L, const Value *R) const { - // Catch self-reference case. - if (L == FnL) { - if (R == FnR) - return 0; - return -1; - } - if (R == FnR) { - if (L == FnL) - return 0; - return 1; - } - - const Constant *ConstL = dyn_cast<Constant>(L); - const Constant *ConstR = dyn_cast<Constant>(R); - if (ConstL && ConstR) { - if (L == R) - return 0; - return cmpConstants(ConstL, ConstR); - } - - if (ConstL) - return 1; - if (ConstR) - return -1; - - const InlineAsm *InlineAsmL = dyn_cast<InlineAsm>(L); - const InlineAsm *InlineAsmR = dyn_cast<InlineAsm>(R); - - if (InlineAsmL && InlineAsmR) - return cmpInlineAsm(InlineAsmL, InlineAsmR); - if (InlineAsmL) - return 1; - if (InlineAsmR) - return -1; - - auto LeftSN = sn_mapL.insert(std::make_pair(L, sn_mapL.size())), - RightSN = sn_mapR.insert(std::make_pair(R, sn_mapR.size())); - - return cmpNumbers(LeftSN.first->second, RightSN.first->second); -} -// Test whether two basic blocks have equivalent behaviour. -int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL, - const BasicBlock *BBR) const { - BasicBlock::const_iterator InstL = BBL->begin(), InstLE = BBL->end(); - BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end(); - - do { - if (int Res = cmpValues(&*InstL, &*InstR)) - return Res; - - const GetElementPtrInst *GEPL = dyn_cast<GetElementPtrInst>(InstL); - const GetElementPtrInst *GEPR = dyn_cast<GetElementPtrInst>(InstR); - - if (GEPL && !GEPR) - return 1; - if (GEPR && !GEPL) - return -1; - - if (GEPL && GEPR) { - if (int Res = - cmpValues(GEPL->getPointerOperand(), GEPR->getPointerOperand())) - return Res; - if (int Res = cmpGEPs(GEPL, GEPR)) - return Res; - } else { - if (int Res = cmpOperations(&*InstL, &*InstR)) - return Res; - assert(InstL->getNumOperands() == InstR->getNumOperands()); - - for (unsigned i = 0, e = InstL->getNumOperands(); i != e; ++i) { - Value *OpL = InstL->getOperand(i); - Value *OpR = InstR->getOperand(i); - if (int Res = cmpValues(OpL, OpR)) - return Res; - // cmpValues should ensure this is true. - assert(cmpTypes(OpL->getType(), OpR->getType()) == 0); - } - } - - ++InstL; - ++InstR; - } while (InstL != InstLE && InstR != InstRE); - - if (InstL != InstLE && InstR == InstRE) - return 1; - if (InstL == InstLE && InstR != InstRE) - return -1; - return 0; -} - -// Test whether the two functions have equivalent behaviour. -int FunctionComparator::compare() { - sn_mapL.clear(); - sn_mapR.clear(); - - if (int Res = cmpAttrs(FnL->getAttributes(), FnR->getAttributes())) - return Res; - - if (int Res = cmpNumbers(FnL->hasGC(), FnR->hasGC())) - return Res; - - if (FnL->hasGC()) { - if (int Res = cmpMem(FnL->getGC(), FnR->getGC())) - return Res; - } - - if (int Res = cmpNumbers(FnL->hasSection(), FnR->hasSection())) - return Res; - - if (FnL->hasSection()) { - if (int Res = cmpMem(FnL->getSection(), FnR->getSection())) - return Res; - } - - if (int Res = cmpNumbers(FnL->isVarArg(), FnR->isVarArg())) - return Res; - - // TODO: if it's internal and only used in direct calls, we could handle this - // case too. - if (int Res = cmpNumbers(FnL->getCallingConv(), FnR->getCallingConv())) - return Res; - - if (int Res = cmpTypes(FnL->getFunctionType(), FnR->getFunctionType())) - return Res; - - assert(FnL->arg_size() == FnR->arg_size() && - "Identically typed functions have different numbers of args!"); - - // Visit the arguments so that they get enumerated in the order they're - // passed in. 
- for (Function::const_arg_iterator ArgLI = FnL->arg_begin(), - ArgRI = FnR->arg_begin(), - ArgLE = FnL->arg_end(); - ArgLI != ArgLE; ++ArgLI, ++ArgRI) { - if (cmpValues(&*ArgLI, &*ArgRI) != 0) - llvm_unreachable("Arguments repeat!"); - } - - // We do a CFG-ordered walk since the actual ordering of the blocks in the - // linked list is immaterial. Our walk starts at the entry block for both - // functions, then takes each block from each terminator in order. As an - // artifact, this also means that unreachable blocks are ignored. - SmallVector<const BasicBlock *, 8> FnLBBs, FnRBBs; - SmallPtrSet<const BasicBlock *, 32> VisitedBBs; // in terms of F1. - - FnLBBs.push_back(&FnL->getEntryBlock()); - FnRBBs.push_back(&FnR->getEntryBlock()); - - VisitedBBs.insert(FnLBBs[0]); - while (!FnLBBs.empty()) { - const BasicBlock *BBL = FnLBBs.pop_back_val(); - const BasicBlock *BBR = FnRBBs.pop_back_val(); - - if (int Res = cmpValues(BBL, BBR)) - return Res; - - if (int Res = cmpBasicBlocks(BBL, BBR)) - return Res; - - const TerminatorInst *TermL = BBL->getTerminator(); - const TerminatorInst *TermR = BBR->getTerminator(); - - assert(TermL->getNumSuccessors() == TermR->getNumSuccessors()); - for (unsigned i = 0, e = TermL->getNumSuccessors(); i != e; ++i) { - if (!VisitedBBs.insert(TermL->getSuccessor(i)).second) - continue; - - FnLBBs.push_back(TermL->getSuccessor(i)); - FnRBBs.push_back(TermR->getSuccessor(i)); - } - } - return 0; -} - -namespace { -// Accumulate the hash of a sequence of 64-bit integers. This is similar to a -// hash of a sequence of 64bit ints, but the entire input does not need to be -// available at once. This interface is necessary for functionHash because it -// needs to accumulate the hash as the structure of the function is traversed -// without saving these values to an intermediate buffer. This form of hashing -// is not often needed, as usually the object to hash is just read from a -// buffer. -class HashAccumulator64 { - uint64_t Hash; -public: - // Initialize to random constant, so the state isn't zero. - HashAccumulator64() { Hash = 0x6acaa36bef8325c5ULL; } - void add(uint64_t V) { - Hash = llvm::hashing::detail::hash_16_bytes(Hash, V); - } - // No finishing is required, because the entire hash value is used. - uint64_t getHash() { return Hash; } -}; -} // end anonymous namespace - -// A function hash is calculated by considering only the number of arguments and -// whether a function is varargs, the order of basic blocks (given by the -// successors of each basic block in depth first order), and the order of -// opcodes of each instruction within each of these basic blocks. This mirrors -// the strategy compare() uses to compare functions by walking the BBs in depth -// first order and comparing each instruction in sequence. Because this hash -// does not look at the operands, it is insensitive to things such as the -// target of calls and the constants used in the function, which makes it useful -// when possibly merging functions which are the same modulo constants and call -// targets. -FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) { - HashAccumulator64 H; - H.add(F.isVarArg()); - H.add(F.arg_size()); - - SmallVector<const BasicBlock *, 8> BBs; - SmallSet<const BasicBlock *, 16> VisitedBBs; - - // Walk the blocks in the same order as FunctionComparator::cmpBasicBlocks(), - // accumulating the hash of the function "structure." 
(BB and opcode sequence) - BBs.push_back(&F.getEntryBlock()); - VisitedBBs.insert(BBs[0]); - while (!BBs.empty()) { - const BasicBlock *BB = BBs.pop_back_val(); - // This random value acts as a block header, as otherwise the partition of - // opcodes into BBs wouldn't affect the hash, only the order of the opcodes - H.add(45798); - for (auto &Inst : *BB) { - H.add(Inst.getOpcode()); - } - const TerminatorInst *Term = BB->getTerminator(); - for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i) { - if (!VisitedBBs.insert(Term->getSuccessor(i)).second) - continue; - BBs.push_back(Term->getSuccessor(i)); - } - } - return H.getHash(); -} - - -namespace { /// MergeFunctions finds functions which will generate identical machine code, /// by considering all pointer types to be equivalent. Once identified, diff --git a/contrib/llvm/lib/Transforms/IPO/PartialInlining.cpp b/contrib/llvm/lib/Transforms/IPO/PartialInlining.cpp index 49c4417..7ef3fc1 100644 --- a/contrib/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/contrib/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -14,6 +14,9 @@ #include "llvm/Transforms/IPO/PartialInlining.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -29,161 +32,193 @@ using namespace llvm; STATISTIC(NumPartialInlined, "Number of functions partially inlined"); namespace { +struct PartialInlinerImpl { + PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(IFI) {} + bool run(Module &M); + Function *unswitchFunction(Function *F); + +private: + InlineFunctionInfo IFI; +}; struct PartialInlinerLegacyPass : public ModulePass { static char ID; // Pass identification, replacement for typeid PartialInlinerLegacyPass() : ModulePass(ID) { initializePartialInlinerLegacyPassPass(*PassRegistry::getPassRegistry()); } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + } bool runOnModule(Module &M) override { if (skipModule(M)) return false; - ModuleAnalysisManager DummyMAM; - auto PA = Impl.run(M, DummyMAM); - return !PA.areAllPreserved(); - } - -private: - PartialInlinerPass Impl; - }; -} - -char PartialInlinerLegacyPass::ID = 0; -INITIALIZE_PASS(PartialInlinerLegacyPass, "partial-inliner", "Partial Inliner", - false, false) -ModulePass *llvm::createPartialInliningPass() { - return new PartialInlinerLegacyPass(); + AssumptionCacheTracker *ACT = &getAnalysis<AssumptionCacheTracker>(); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = + [&ACT](Function &F) -> AssumptionCache & { + return ACT->getAssumptionCache(F); + }; + InlineFunctionInfo IFI(nullptr, &GetAssumptionCache); + return PartialInlinerImpl(IFI).run(M); + } +}; } -Function *PartialInlinerPass::unswitchFunction(Function *F) { +Function *PartialInlinerImpl::unswitchFunction(Function *F) { // First, verify that this function is an unswitching candidate... 
- BasicBlock *entryBlock = &F->front(); - BranchInst *BR = dyn_cast<BranchInst>(entryBlock->getTerminator()); + BasicBlock *EntryBlock = &F->front(); + BranchInst *BR = dyn_cast<BranchInst>(EntryBlock->getTerminator()); if (!BR || BR->isUnconditional()) return nullptr; - - BasicBlock* returnBlock = nullptr; - BasicBlock* nonReturnBlock = nullptr; - unsigned returnCount = 0; - for (BasicBlock *BB : successors(entryBlock)) { + + BasicBlock *ReturnBlock = nullptr; + BasicBlock *NonReturnBlock = nullptr; + unsigned ReturnCount = 0; + for (BasicBlock *BB : successors(EntryBlock)) { if (isa<ReturnInst>(BB->getTerminator())) { - returnBlock = BB; - returnCount++; + ReturnBlock = BB; + ReturnCount++; } else - nonReturnBlock = BB; + NonReturnBlock = BB; } - - if (returnCount != 1) + + if (ReturnCount != 1) return nullptr; - + // Clone the function, so that we can hack away on it. ValueToValueMapTy VMap; - Function* duplicateFunction = CloneFunction(F, VMap); - duplicateFunction->setLinkage(GlobalValue::InternalLinkage); - BasicBlock* newEntryBlock = cast<BasicBlock>(VMap[entryBlock]); - BasicBlock* newReturnBlock = cast<BasicBlock>(VMap[returnBlock]); - BasicBlock* newNonReturnBlock = cast<BasicBlock>(VMap[nonReturnBlock]); - + Function *DuplicateFunction = CloneFunction(F, VMap); + DuplicateFunction->setLinkage(GlobalValue::InternalLinkage); + BasicBlock *NewEntryBlock = cast<BasicBlock>(VMap[EntryBlock]); + BasicBlock *NewReturnBlock = cast<BasicBlock>(VMap[ReturnBlock]); + BasicBlock *NewNonReturnBlock = cast<BasicBlock>(VMap[NonReturnBlock]); + // Go ahead and update all uses to the duplicate, so that we can just // use the inliner functionality when we're done hacking. - F->replaceAllUsesWith(duplicateFunction); - + F->replaceAllUsesWith(DuplicateFunction); + // Special hackery is needed with PHI nodes that have inputs from more than // one extracted block. For simplicity, just split the PHIs into a two-level // sequence of PHIs, some of which will go in the extracted region, and some // of which will go outside. 
- BasicBlock* preReturn = newReturnBlock; - newReturnBlock = newReturnBlock->splitBasicBlock( - newReturnBlock->getFirstNonPHI()->getIterator()); - BasicBlock::iterator I = preReturn->begin(); - Instruction *Ins = &newReturnBlock->front(); - while (I != preReturn->end()) { - PHINode* OldPhi = dyn_cast<PHINode>(I); - if (!OldPhi) break; - - PHINode *retPhi = PHINode::Create(OldPhi->getType(), 2, "", Ins); - OldPhi->replaceAllUsesWith(retPhi); - Ins = newReturnBlock->getFirstNonPHI(); - - retPhi->addIncoming(&*I, preReturn); - retPhi->addIncoming(OldPhi->getIncomingValueForBlock(newEntryBlock), - newEntryBlock); - OldPhi->removeIncomingValue(newEntryBlock); - + BasicBlock *PreReturn = NewReturnBlock; + NewReturnBlock = NewReturnBlock->splitBasicBlock( + NewReturnBlock->getFirstNonPHI()->getIterator()); + BasicBlock::iterator I = PreReturn->begin(); + Instruction *Ins = &NewReturnBlock->front(); + while (I != PreReturn->end()) { + PHINode *OldPhi = dyn_cast<PHINode>(I); + if (!OldPhi) + break; + + PHINode *RetPhi = PHINode::Create(OldPhi->getType(), 2, "", Ins); + OldPhi->replaceAllUsesWith(RetPhi); + Ins = NewReturnBlock->getFirstNonPHI(); + + RetPhi->addIncoming(&*I, PreReturn); + RetPhi->addIncoming(OldPhi->getIncomingValueForBlock(NewEntryBlock), + NewEntryBlock); + OldPhi->removeIncomingValue(NewEntryBlock); + ++I; } - newEntryBlock->getTerminator()->replaceUsesOfWith(preReturn, newReturnBlock); - + NewEntryBlock->getTerminator()->replaceUsesOfWith(PreReturn, NewReturnBlock); + // Gather up the blocks that we're going to extract. - std::vector<BasicBlock*> toExtract; - toExtract.push_back(newNonReturnBlock); - for (BasicBlock &BB : *duplicateFunction) - if (&BB != newEntryBlock && &BB != newReturnBlock && - &BB != newNonReturnBlock) - toExtract.push_back(&BB); + std::vector<BasicBlock *> ToExtract; + ToExtract.push_back(NewNonReturnBlock); + for (BasicBlock &BB : *DuplicateFunction) + if (&BB != NewEntryBlock && &BB != NewReturnBlock && + &BB != NewNonReturnBlock) + ToExtract.push_back(&BB); // The CodeExtractor needs a dominator tree. DominatorTree DT; - DT.recalculate(*duplicateFunction); + DT.recalculate(*DuplicateFunction); + + // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo. + LoopInfo LI(DT); + BranchProbabilityInfo BPI(*DuplicateFunction, LI); + BlockFrequencyInfo BFI(*DuplicateFunction, BPI, LI); // Extract the body of the if. - Function* extractedFunction - = CodeExtractor(toExtract, &DT).extractCodeRegion(); - - InlineFunctionInfo IFI; - + Function *ExtractedFunction = + CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, &BFI, &BPI) + .extractCodeRegion(); + // Inline the top-level if test into all callers. - std::vector<User *> Users(duplicateFunction->user_begin(), - duplicateFunction->user_end()); + std::vector<User *> Users(DuplicateFunction->user_begin(), + DuplicateFunction->user_end()); for (User *User : Users) if (CallInst *CI = dyn_cast<CallInst>(User)) InlineFunction(CI, IFI); else if (InvokeInst *II = dyn_cast<InvokeInst>(User)) InlineFunction(II, IFI); - + // Ditch the duplicate, since we're done with it, and rewrite all remaining // users (function pointers, etc.) back to the original function. 
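// ---- Sketch: the shape partial inlining aims for ------------------------
// unswitchFunction() clones the function, outlines everything past the
// entry-block test into a new function, and lets the inliner copy only the
// cheap guard into callers.  The source-level before/after below uses
// invented names and a hypothetical guard condition; the real pass works on
// LLVM IR via CodeExtractor and InlineFunction.

// Original: a cheap guard in the entry block followed by a large cold body.
int original(int X) {
  if (X <= 0)          // entry block: conditional branch to a return block
    return 0;
  int Acc = 0;         // ... imagine many more instructions here ...
  for (int I = 0; I < X; ++I)
    Acc += I * X;
  return Acc;
}

// After partial inlining, the module behaves as if it were written so:
int original_body(int X) {    // region extracted by CodeExtractor
  int Acc = 0;
  for (int I = 0; I < X; ++I)
    Acc += I * X;
  return Acc;
}
int original_guarded(int X) { // only this thin wrapper gets inlined
  return X <= 0 ? 0 : original_body(X);
}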
- duplicateFunction->replaceAllUsesWith(F); - duplicateFunction->eraseFromParent(); - + DuplicateFunction->replaceAllUsesWith(F); + DuplicateFunction->eraseFromParent(); + ++NumPartialInlined; - - return extractedFunction; + + return ExtractedFunction; } -PreservedAnalyses PartialInlinerPass::run(Module &M, ModuleAnalysisManager &) { - std::vector<Function*> worklist; - worklist.reserve(M.size()); +bool PartialInlinerImpl::run(Module &M) { + std::vector<Function *> Worklist; + Worklist.reserve(M.size()); for (Function &F : M) if (!F.use_empty() && !F.isDeclaration()) - worklist.push_back(&F); - - bool changed = false; - while (!worklist.empty()) { - Function* currFunc = worklist.back(); - worklist.pop_back(); - - if (currFunc->use_empty()) continue; - - bool recursive = false; - for (User *U : currFunc->users()) - if (Instruction* I = dyn_cast<Instruction>(U)) - if (I->getParent()->getParent() == currFunc) { - recursive = true; + Worklist.push_back(&F); + + bool Changed = false; + while (!Worklist.empty()) { + Function *CurrFunc = Worklist.back(); + Worklist.pop_back(); + + if (CurrFunc->use_empty()) + continue; + + bool Recursive = false; + for (User *U : CurrFunc->users()) + if (Instruction *I = dyn_cast<Instruction>(U)) + if (I->getParent()->getParent() == CurrFunc) { + Recursive = true; break; } - if (recursive) continue; - - - if (Function* newFunc = unswitchFunction(currFunc)) { - worklist.push_back(newFunc); - changed = true; + if (Recursive) + continue; + + if (Function *NewFunc = unswitchFunction(CurrFunc)) { + Worklist.push_back(NewFunc); + Changed = true; } - } - if (changed) + return Changed; +} + +char PartialInlinerLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner", + "Partial Inliner", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_END(PartialInlinerLegacyPass, "partial-inliner", + "Partial Inliner", false, false) + +ModulePass *llvm::createPartialInliningPass() { + return new PartialInlinerLegacyPass(); +} + +PreservedAnalyses PartialInlinerPass::run(Module &M, + ModuleAnalysisManager &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = + [&FAM](Function &F) -> AssumptionCache & { + return FAM.getResult<AssumptionAnalysis>(F); + }; + InlineFunctionInfo IFI(nullptr, &GetAssumptionCache); + if (PartialInlinerImpl(IFI).run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } diff --git a/contrib/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/contrib/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp index df6a48e..941efb2 100644 --- a/contrib/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/contrib/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -19,6 +19,7 @@ #include "llvm/Analysis/CFLAndersAliasAnalysis.h" #include "llvm/Analysis/CFLSteensAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/Passes.h" #include "llvm/Analysis/ScopedNoAliasAA.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -66,14 +67,13 @@ static cl::opt<bool> RunLoopRerolling("reroll-loops", cl::Hidden, cl::desc("Run the loop rerolling pass")); -static cl::opt<bool> -RunFloat2Int("float-to-int", cl::Hidden, cl::init(true), - cl::desc("Run the float2int (float demotion) pass")); - static cl::opt<bool> RunLoadCombine("combine-loads", cl::init(false), cl::Hidden, cl::desc("Run the load combining pass")); +static cl::opt<bool> 
RunNewGVN("enable-newgvn", cl::init(false), cl::Hidden, + cl::desc("Run the NewGVN pass")); + static cl::opt<bool> RunSLPAfterLoopVectorization("run-slp-after-loop-vectorization", cl::init(true), cl::Hidden, @@ -91,8 +91,7 @@ static cl::opt<CFLAAType> clEnumValN(CFLAAType::Andersen, "anders", "Enable inclusion-based CFL-AA"), clEnumValN(CFLAAType::Both, "both", - "Enable both variants of CFL-aa"), - clEnumValEnd)); + "Enable both variants of CFL-AA"))); static cl::opt<bool> EnableMLSM("mlsm", cl::init(true), cl::Hidden, @@ -111,10 +110,17 @@ static cl::opt<bool> EnableLoopLoadElim( "enable-loop-load-elim", cl::init(true), cl::Hidden, cl::desc("Enable the LoopLoadElimination Pass")); -static cl::opt<std::string> RunPGOInstrGen( - "profile-generate", cl::init(""), cl::Hidden, - cl::desc("Enable generation phase of PGO instrumentation and specify the " - "path of profile data file")); +static cl::opt<bool> + EnablePrepareForThinLTO("prepare-for-thinlto", cl::init(false), cl::Hidden, + cl::desc("Enable preparation for ThinLTO.")); + +static cl::opt<bool> RunPGOInstrGen( + "profile-generate", cl::init(false), cl::Hidden, + cl::desc("Enable PGO instrumentation.")); + +static cl::opt<std::string> + PGOOutputFile("profile-generate-file", cl::init(""), cl::Hidden, + cl::desc("Specify the path of profile data file.")); static cl::opt<std::string> RunPGOInstrUse( "profile-use", cl::init(""), cl::Hidden, cl::value_desc("filename"), @@ -136,14 +142,18 @@ static cl::opt<int> PreInlineThreshold( static cl::opt<bool> EnableGVNHoist( "enable-gvn-hoist", cl::init(false), cl::Hidden, - cl::desc("Enable the experimental GVN Hoisting pass")); + cl::desc("Enable the GVN hoisting pass")); + +static cl::opt<bool> + DisableLibCallsShrinkWrap("disable-libcalls-shrinkwrap", cl::init(false), + cl::Hidden, + cl::desc("Disable shrink-wrap library calls")); PassManagerBuilder::PassManagerBuilder() { OptLevel = 2; SizeLevel = 0; LibraryInfo = nullptr; Inliner = nullptr; - ModuleSummary = nullptr; DisableUnitAtATime = false; DisableUnrollLoops = false; BBVectorize = RunBBVectorization; @@ -151,14 +161,16 @@ PassManagerBuilder::PassManagerBuilder() { LoopVectorize = RunLoopVectorization; RerollLoops = RunLoopRerolling; LoadCombine = RunLoadCombine; + NewGVN = RunNewGVN; DisableGVNLoadPRE = false; VerifyInput = false; VerifyOutput = false; MergeFunctions = false; PrepareForLTO = false; - PGOInstrGen = RunPGOInstrGen; + EnablePGOInstrGen = RunPGOInstrGen; + PGOInstrGen = PGOOutputFile; PGOInstrUse = RunPGOInstrUse; - PrepareForThinLTO = false; + PrepareForThinLTO = EnablePrepareForThinLTO; PerformThinLTO = false; } @@ -243,24 +255,34 @@ void PassManagerBuilder::populateFunctionPassManager( // Do PGO instrumentation generation or use pass as the option specified. void PassManagerBuilder::addPGOInstrPasses(legacy::PassManagerBase &MPM) { - if (PGOInstrGen.empty() && PGOInstrUse.empty()) + if (!EnablePGOInstrGen && PGOInstrUse.empty()) return; // Perform the preinline and cleanup passes for O1 and above. // And avoid doing them if optimizing for size. if (OptLevel > 0 && SizeLevel == 0 && !DisablePreInliner) { - // Create preinline pass. - MPM.add(createFunctionInliningPass(PreInlineThreshold)); + // Create preinline pass. We construct an InlineParams object and specify + // the threshold here to avoid the command line options of the regular + // inliner to influence pre-inlining. The only fields of InlineParams we + // care about are DefaultThreshold and HintThreshold. 
+ InlineParams IP; + IP.DefaultThreshold = PreInlineThreshold; + // FIXME: The hint threshold has the same value used by the regular inliner. + // This should probably be lowered after performance testing. + IP.HintThreshold = 325; + + MPM.add(createFunctionInliningPass(IP)); MPM.add(createSROAPass()); MPM.add(createEarlyCSEPass()); // Catch trivial redundancies MPM.add(createCFGSimplificationPass()); // Merge & remove BBs MPM.add(createInstructionCombiningPass()); // Combine silly seq's addExtensionsToPM(EP_Peephole, MPM); } - if (!PGOInstrGen.empty()) { + if (EnablePGOInstrGen) { MPM.add(createPGOInstrumentationGenLegacyPass()); // Add the profile lowering pass. InstrProfOptions Options; - Options.InstrProfileOutput = PGOInstrGen; + if (!PGOInstrGen.empty()) + Options.InstrProfileOutput = PGOInstrGen; MPM.add(createInstrProfilingLegacyPass(Options)); } if (!PGOInstrUse.empty()) @@ -279,6 +301,8 @@ void PassManagerBuilder::addFunctionSimplificationPasses( MPM.add(createCFGSimplificationPass()); // Merge & remove BBs // Combine silly seq's addInstructionCombiningPass(MPM); + if (SizeLevel == 0 && !DisableLibCallsShrinkWrap) + MPM.add(createLibCallsShrinkWrapPass()); addExtensionsToPM(EP_Peephole, MPM); MPM.add(createTailCallEliminationPass()); // Eliminate tail calls @@ -304,7 +328,8 @@ void PassManagerBuilder::addFunctionSimplificationPasses( if (OptLevel > 1) { if (EnableMLSM) MPM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds - MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies + MPM.add(NewGVN ? createNewGVNPass() + : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies } MPM.add(createMemCpyOptPass()); // Remove memcpy / form memset MPM.add(createSCCPPass()); // Constant prop with SCCP @@ -336,7 +361,9 @@ void PassManagerBuilder::addFunctionSimplificationPasses( addInstructionCombiningPass(MPM); addExtensionsToPM(EP_Peephole, MPM); if (OptLevel > 1 && UseGVNAfterVectorization) - MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies + MPM.add(NewGVN + ? createNewGVNPass() + : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies else MPM.add(createEarlyCSEPass()); // Catch trivial redundancies @@ -358,6 +385,11 @@ void PassManagerBuilder::addFunctionSimplificationPasses( void PassManagerBuilder::populateModulePassManager( legacy::PassManagerBase &MPM) { + if (!PGOSampleUse.empty()) { + MPM.add(createPruneEHPass()); + MPM.add(createSampleProfileLoaderPass(PGOSampleUse)); + } + // Allow forcing function attributes as a debugging and tuning aid. MPM.add(createForceFunctionAttrsLegacyPass()); @@ -380,6 +412,10 @@ void PassManagerBuilder::populateModulePassManager( else if (!GlobalExtensions->empty() || !Extensions.empty()) MPM.add(createBarrierNoopPass()); + if (PrepareForThinLTO) + // Rename anon globals to be able to export them in the summary. + MPM.add(createNameAnonGlobalPass()); + addExtensionsToPM(EP_EnabledOnOptLevel0, MPM); return; } @@ -390,6 +426,16 @@ void PassManagerBuilder::populateModulePassManager( addInitialAliasAnalysisPasses(MPM); + // For ThinLTO there are two passes of indirect call promotion. The + // first is during the compile phase when PerformThinLTO=false and + // intra-module indirect call targets are promoted. The second is during + // the ThinLTO backend when PerformThinLTO=true, when we promote imported + // inter-module indirect calls. For that we perform indirect call promotion + // earlier in the pass pipeline, here before globalopt. 
Otherwise imported + // available_externally functions look unreferenced and are removed. + if (PerformThinLTO) + MPM.add(createPGOIndirectCallPromotionLegacyPass(/*InLTO = */ true)); + if (!DisableUnitAtATime) { // Infer attributes about declarations if possible. MPM.add(createInferFunctionAttrsLegacyPass()); @@ -412,11 +458,12 @@ void PassManagerBuilder::populateModulePassManager( /// PGO instrumentation is added during the compile phase for ThinLTO, do /// not run it a second time addPGOInstrPasses(MPM); + // Indirect call promotion that promotes intra-module targets only. + // For ThinLTO this is done earlier due to interactions with globalopt + // for imported functions. + MPM.add(createPGOIndirectCallPromotionLegacyPass()); } - // Indirect call promotion that promotes intra-module targets only. - MPM.add(createPGOIndirectCallPromotionLegacyPass()); - if (EnableNonLTOGlobalsModRef) // We add a module alias analysis pass here. In part due to bugs in the // analysis infrastructure this "works" in that the analysis stays alive @@ -435,6 +482,7 @@ void PassManagerBuilder::populateModulePassManager( if (OptLevel > 2) MPM.add(createArgumentPromotionPass()); // Scalarize uninlined fn args + addExtensionsToPM(EP_CGSCCOptimizerLate, MPM); addFunctionSimplificationPasses(MPM); // FIXME: This is a HACK! The inliner pass above implicitly creates a CGSCC @@ -464,8 +512,8 @@ void PassManagerBuilder::populateModulePassManager( if (PrepareForThinLTO) { // Reduce the size of the IR as much as possible. MPM.add(createGlobalOptimizerPass()); - // Rename anon function to be able to export them in the summary. - MPM.add(createNameAnonFunctionPass()); + // Rename anon globals to be able to export them in the summary. + MPM.add(createNameAnonGlobalPass()); return; } @@ -502,8 +550,7 @@ void PassManagerBuilder::populateModulePassManager( // correct in the face of IR changes). MPM.add(createGlobalsAAWrapperPass()); - if (RunFloat2Int) - MPM.add(createFloat2IntPass()); + MPM.add(createFloat2IntPass()); addExtensionsToPM(EP_VectorizerStart, MPM); @@ -516,7 +563,7 @@ void PassManagerBuilder::populateModulePassManager( // into separate loop that would otherwise inhibit vectorization. This is // currently only performed for loops marked with the metadata // llvm.loop.distribute=true or when -enable-loop-distribute is specified. - MPM.add(createLoopDistributePass(/*ProcessAllLoopsByDefault=*/false)); + MPM.add(createLoopDistributePass()); MPM.add(createLoopVectorizePass(DisableUnrollLoops, LoopVectorize)); @@ -560,7 +607,9 @@ void PassManagerBuilder::populateModulePassManager( addInstructionCombiningPass(MPM); addExtensionsToPM(EP_Peephole, MPM); if (OptLevel > 1 && UseGVNAfterVectorization) - MPM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies + MPM.add(NewGVN + ? createNewGVNPass() + : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies else MPM.add(createEarlyCSEPass()); // Catch trivial redundancies @@ -585,10 +634,7 @@ void PassManagerBuilder::populateModulePassManager( // outer loop. LICM pass can help to promote the runtime check out if the // checked value is loop invariant. MPM.add(createLICMPass()); - - // Get rid of LCSSA nodes. - MPM.add(createInstructionSimplifierPass()); - } + } // After vectorization and unrolling, assume intrinsics may tell us more // about pointer alignments. 
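// ---- Sketch: driving PassManagerBuilder ----------------------------------
// For context, this is roughly how a frontend of this vintage wires up the
// builder whose populateModulePassManager() is being reworked above.  Treat
// it as an approximation: the fields used here (OptLevel, SizeLevel,
// Inliner, PrepareForThinLTO) are the ones referenced in this patch, but
// target setup and error handling are omitted.

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"

void runO2Pipeline(llvm::Module &M) {
  llvm::PassManagerBuilder PMB;
  PMB.OptLevel = 2;
  PMB.SizeLevel = 0;
  PMB.Inliner = llvm::createFunctionInliningPass(); // default thresholds
  PMB.PrepareForThinLTO = false; // set true for the ThinLTO compile step

  llvm::legacy::FunctionPassManager FPM(&M);
  llvm::legacy::PassManager MPM;
  PMB.populateFunctionPassManager(FPM);
  PMB.populateModulePassManager(MPM);

  FPM.doInitialization();
  for (llvm::Function &F : M)
    FPM.run(F);
  FPM.doFinalization();
  MPM.run(M);
}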
@@ -609,6 +655,13 @@ void PassManagerBuilder::populateModulePassManager( if (MergeFunctions) MPM.add(createMergeFunctionsPass()); + // LoopSink pass sinks instructions hoisted by LICM, which serves as a + // canonicalization pass that enables other optimizations. As a result, + // LoopSink pass needs to be a very late IR pass to avoid undoing LICM + // result too early. + MPM.add(createLoopSinkPass()); + // Get rid of LCSSA nodes. + MPM.add(createInstructionSimplifierPass()); addExtensionsToPM(EP_OptimizerLast, MPM); } @@ -620,9 +673,6 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // Provide AliasAnalysis services for optimizations. addInitialAliasAnalysisPasses(PM); - if (ModuleSummary) - PM.add(createFunctionImportPass(ModuleSummary)); - // Allow forcing function attributes as a debugging and tuning aid. PM.add(createForceFunctionAttrsLegacyPass()); @@ -647,6 +697,11 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createPostOrderFunctionAttrsLegacyPass()); PM.add(createReversePostOrderFunctionAttrsPass()); + // Split globals using inrange annotations on GEP indices. This can help + // improve the quality of generated code when virtual constant propagation or + // control flow integrity are enabled. + PM.add(createGlobalSplitPass()); + // Apply whole-program devirtualization and virtual constant propagation. PM.add(createWholeProgramDevirtPass()); @@ -706,7 +761,8 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createLICMPass()); // Hoist loop invariants. if (EnableMLSM) PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. - PM.add(createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. + PM.add(NewGVN ? createNewGVNPass() + : createGVNPass(DisableGVNLoadPRE)); // Remove redundancies. PM.add(createMemCpyOptPass()); // Remove dead memcpys. // Nuke dead stores. @@ -777,9 +833,6 @@ void PassManagerBuilder::populateThinLTOPassManager( if (VerifyInput) PM.add(createVerifierPass()); - if (ModuleSummary) - PM.add(createFunctionImportPass(ModuleSummary)); - populateModulePassManager(PM); if (VerifyOutput) @@ -804,7 +857,8 @@ void PassManagerBuilder::populateLTOPassManager(legacy::PassManagerBase &PM) { // Lower type metadata and the type.test intrinsic. This pass supports Clang's // control flow integrity mechanisms (-fsanitize=cfi*) and needs to run at // link time if CFI is enabled. The pass does nothing if CFI is disabled. - PM.add(createLowerTypeTestsPass()); + PM.add(createLowerTypeTestsPass(LowerTypeTestsSummaryAction::None, + /*Summary=*/nullptr)); if (OptLevel != 0) addLateLTOOptimizationPasses(PM); diff --git a/contrib/llvm/lib/Transforms/IPO/PruneEH.cpp b/contrib/llvm/lib/Transforms/IPO/PruneEH.cpp index 2aa3fa5..d9acb9b 100644 --- a/contrib/llvm/lib/Transforms/IPO/PruneEH.cpp +++ b/contrib/llvm/lib/Transforms/IPO/PruneEH.cpp @@ -90,10 +90,7 @@ static bool runImpl(CallGraphSCC &SCC, CallGraph &CG) { if (!F) { SCCMightUnwind = true; SCCMightReturn = true; - } else if (F->isDeclaration() || F->isInterposable()) { - // Note: isInterposable (as opposed to hasExactDefinition) is fine above, - // since we're not inferring new attributes here, but only using existing, - // assumed to be correct, function attributes. 
+ } else if (!F->hasExactDefinition()) { SCCMightUnwind |= !F->doesNotThrow(); SCCMightReturn |= !F->doesNotReturn(); } else { diff --git a/contrib/llvm/lib/Transforms/IPO/SampleProfile.cpp b/contrib/llvm/lib/Transforms/IPO/SampleProfile.cpp index 39de108..6a43f8d 100644 --- a/contrib/llvm/lib/Transforms/IPO/SampleProfile.cpp +++ b/contrib/llvm/lib/Transforms/IPO/SampleProfile.cpp @@ -88,6 +88,52 @@ typedef DenseMap<Edge, uint64_t> EdgeWeightMap; typedef DenseMap<const BasicBlock *, SmallVector<const BasicBlock *, 8>> BlockEdgeMap; +class SampleCoverageTracker { +public: + SampleCoverageTracker() : SampleCoverage(), TotalUsedSamples(0) {} + + bool markSamplesUsed(const FunctionSamples *FS, uint32_t LineOffset, + uint32_t Discriminator, uint64_t Samples); + unsigned computeCoverage(unsigned Used, unsigned Total) const; + unsigned countUsedRecords(const FunctionSamples *FS) const; + unsigned countBodyRecords(const FunctionSamples *FS) const; + uint64_t getTotalUsedSamples() const { return TotalUsedSamples; } + uint64_t countBodySamples(const FunctionSamples *FS) const; + void clear() { + SampleCoverage.clear(); + TotalUsedSamples = 0; + } + +private: + typedef std::map<LineLocation, unsigned> BodySampleCoverageMap; + typedef DenseMap<const FunctionSamples *, BodySampleCoverageMap> + FunctionSamplesCoverageMap; + + /// Coverage map for sampling records. + /// + /// This map keeps a record of sampling records that have been matched to + /// an IR instruction. This is used to detect some form of staleness in + /// profiles (see flag -sample-profile-check-coverage). + /// + /// Each entry in the map corresponds to a FunctionSamples instance. This is + /// another map that counts how many times the sample record at the + /// given location has been used. + FunctionSamplesCoverageMap SampleCoverage; + + /// Number of samples used from the profile. + /// + /// When a sampling record is used for the first time, the samples from + /// that record are added to this accumulator. Coverage is later computed + /// based on the total number of samples available in this function and + /// its callsites. + /// + /// Note that this accumulator tracks samples used from a single function + /// and all the inlined callsites. Strictly, we should have a map of counters + /// keyed by FunctionSamples pointers, but these stats are cleared after + /// every function, so we just need to keep a single counter. + uint64_t TotalUsedSamples; +}; + /// \brief Sample profile pass. 
/// /// This pass reads profile data from the file specified by @@ -110,9 +156,9 @@ protected: bool runOnFunction(Function &F); unsigned getFunctionLoc(Function &F); bool emitAnnotations(Function &F); - ErrorOr<uint64_t> getInstWeight(const Instruction &I) const; - ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB) const; - const FunctionSamples *findCalleeFunctionSamples(const CallInst &I) const; + ErrorOr<uint64_t> getInstWeight(const Instruction &I); + ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB); + const FunctionSamples *findCalleeFunctionSamples(const Instruction &I) const; const FunctionSamples *findFunctionSamples(const Instruction &I) const; bool inlineHotFunctions(Function &F); void printEdgeWeight(raw_ostream &OS, Edge E); @@ -125,7 +171,7 @@ protected: void propagateWeights(Function &F); uint64_t visitEdge(Edge E, unsigned *NumUnknownEdges, Edge *UnknownEdge); void buildEdges(Function &F); - bool propagateThroughEdges(Function &F); + bool propagateThroughEdges(Function &F, bool UpdateBlockCount); void computeDominanceAndLoopInfo(Function &F); unsigned getOffset(unsigned L, unsigned H) const; void clearFunctionData(); @@ -169,6 +215,8 @@ protected: /// \brief Successors for each basic block in the CFG. BlockEdgeMap Successors; + SampleCoverageTracker CoverageTracker; + /// \brief Profile reader object. std::unique_ptr<SampleProfileReader> Reader; @@ -176,7 +224,7 @@ protected: FunctionSamples *Samples; /// \brief Name of the profile file to load. - StringRef Filename; + std::string Filename; /// \brief Flag indicating whether the profile input loaded successfully. bool ProfileIsValid; @@ -204,64 +252,17 @@ public: bool doInitialization(Module &M) override { return SampleLoader.doInitialization(M); } - const char *getPassName() const override { return "Sample profile pass"; } + StringRef getPassName() const override { return "Sample profile pass"; } bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); } -private: - SampleProfileLoader SampleLoader; -}; - -class SampleCoverageTracker { -public: - SampleCoverageTracker() : SampleCoverage(), TotalUsedSamples(0) {} - - bool markSamplesUsed(const FunctionSamples *FS, uint32_t LineOffset, - uint32_t Discriminator, uint64_t Samples); - unsigned computeCoverage(unsigned Used, unsigned Total) const; - unsigned countUsedRecords(const FunctionSamples *FS) const; - unsigned countBodyRecords(const FunctionSamples *FS) const; - uint64_t getTotalUsedSamples() const { return TotalUsedSamples; } - uint64_t countBodySamples(const FunctionSamples *FS) const; - void clear() { - SampleCoverage.clear(); - TotalUsedSamples = 0; - } private: - typedef std::map<LineLocation, unsigned> BodySampleCoverageMap; - typedef DenseMap<const FunctionSamples *, BodySampleCoverageMap> - FunctionSamplesCoverageMap; - - /// Coverage map for sampling records. - /// - /// This map keeps a record of sampling records that have been matched to - /// an IR instruction. This is used to detect some form of staleness in - /// profiles (see flag -sample-profile-check-coverage). - /// - /// Each entry in the map corresponds to a FunctionSamples instance. This is - /// another map that counts how many times the sample record at the - /// given location has been used. - FunctionSamplesCoverageMap SampleCoverage; - - /// Number of samples used from the profile. 
- /// - /// When a sampling record is used for the first time, the samples from - /// that record are added to this accumulator. Coverage is later computed - /// based on the total number of samples available in this function and - /// its callsites. - /// - /// Note that this accumulator tracks samples used from a single function - /// and all the inlined callsites. Strictly, we should have a map of counters - /// keyed by FunctionSamples pointers, but these stats are cleared after - /// every function, so we just need to keep a single counter. - uint64_t TotalUsedSamples; + SampleProfileLoader SampleLoader; }; -SampleCoverageTracker CoverageTracker; - /// Return true if the given callsite is hot wrt to its caller. /// /// Functions that were inlined in the original binary will be represented @@ -451,7 +452,7 @@ void SampleProfileLoader::printBlockWeight(raw_ostream &OS, /// /// \returns the weight of \p Inst. ErrorOr<uint64_t> -SampleProfileLoader::getInstWeight(const Instruction &Inst) const { +SampleProfileLoader::getInstWeight(const Instruction &Inst) { const DebugLoc &DLoc = Inst.getDebugLoc(); if (!DLoc) return std::error_code(); @@ -460,18 +461,28 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) const { if (!FS) return std::error_code(); - // Ignore all dbg_value intrinsics. - const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst); - if (II && II->getIntrinsicID() == Intrinsic::dbg_value) + // Ignore all intrinsics and branch instructions. + // Branch instruction usually contains debug info from sources outside of + // the residing basic block, thus we ignore them during annotation. + if (isa<BranchInst>(Inst) || isa<IntrinsicInst>(Inst)) return std::error_code(); + // If a call/invoke instruction is inlined in profile, but not inlined here, + // it means that the inlined callsite has no sample, thus the call + // instruction should have 0 count. + bool IsCall = isa<CallInst>(Inst) || isa<InvokeInst>(Inst); + if (IsCall && findCalleeFunctionSamples(Inst)) + return 0; + const DILocation *DIL = DLoc; unsigned Lineno = DLoc.getLine(); unsigned HeaderLineno = DIL->getScope()->getSubprogram()->getLine(); uint32_t LineOffset = getOffset(Lineno, HeaderLineno); uint32_t Discriminator = DIL->getDiscriminator(); - ErrorOr<uint64_t> R = FS->findSamplesAt(LineOffset, Discriminator); + ErrorOr<uint64_t> R = IsCall + ? FS->findCallSamplesAt(LineOffset, Discriminator) + : FS->findSamplesAt(LineOffset, Discriminator); if (R) { bool FirstMark = CoverageTracker.markSamplesUsed(FS, LineOffset, Discriminator, R.get()); @@ -488,13 +499,6 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) const { << Inst << " (line offset: " << Lineno - HeaderLineno << "." << DIL->getDiscriminator() << " - weight: " << R.get() << ")\n"); - } else { - // If a call instruction is inlined in profile, but not inlined here, - // it means that the inlined callsite has no sample, thus the call - // instruction should have 0 count. - const CallInst *CI = dyn_cast<CallInst>(&Inst); - if (CI && findCalleeFunctionSamples(*CI)) - R = 0; } return R; } @@ -508,23 +512,17 @@ SampleProfileLoader::getInstWeight(const Instruction &Inst) const { /// /// \returns the weight for \p BB. 
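// ---- Sketch: the (line offset, discriminator) sample lookup --------------
// getInstWeight() above keys samples by the instruction's line number minus
// the function's header line, plus the DWARF discriminator.  Below is a
// simplified standalone model of that lookup; the map and types are invented
// for illustration, whereas the real code queries
// FunctionSamples::findSamplesAt / findCallSamplesAt and reports an
// error_code when no record exists.

#include <cstdint>
#include <map>
#include <utility>

using LineLoc = std::pair<uint32_t, uint32_t>; // {line offset, discriminator}

struct FunctionSamplesSketch {
  std::map<LineLoc, uint64_t> BodySamples;

  // Returns the sample count recorded for an instruction, or 0 when the
  // profile has no record at that location.
  uint64_t weightAt(uint32_t InstLine, uint32_t HeaderLine,
                    uint32_t Discriminator) const {
    LineLoc Key{InstLine - HeaderLine, Discriminator};
    auto It = BodySamples.find(Key);
    return It == BodySamples.end() ? 0 : It->second;
  }
};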
ErrorOr<uint64_t> -SampleProfileLoader::getBlockWeight(const BasicBlock *BB) const { - DenseMap<uint64_t, uint64_t> CM; +SampleProfileLoader::getBlockWeight(const BasicBlock *BB) { + uint64_t Max = 0; + bool HasWeight = false; for (auto &I : BB->getInstList()) { const ErrorOr<uint64_t> &R = getInstWeight(I); - if (R) CM[R.get()]++; - } - if (CM.size() == 0) return std::error_code(); - uint64_t W = 0, C = 0; - for (const auto &C_W : CM) { - if (C_W.second == W) { - C = std::max(C, C_W.first); - } else if (C_W.second > W) { - C = C_W.first; - W = C_W.second; + if (R) { + Max = std::max(Max, R.get()); + HasWeight = true; } } - return C; + return HasWeight ? ErrorOr<uint64_t>(Max) : std::error_code(); } /// \brief Compute and store the weights of every basic block. @@ -551,18 +549,18 @@ bool SampleProfileLoader::computeBlockWeights(Function &F) { /// \brief Get the FunctionSamples for a call instruction. /// -/// The FunctionSamples of a call instruction \p Inst is the inlined +/// The FunctionSamples of a call/invoke instruction \p Inst is the inlined /// instance in which that call instruction is calling to. It contains /// all samples that resides in the inlined instance. We first find the /// inlined instance in which the call instruction is from, then we /// traverse its children to find the callsite with the matching -/// location and callee function name. +/// location. /// -/// \param Inst Call instruction to query. +/// \param Inst Call/Invoke instruction to query. /// /// \returns The FunctionSamples pointer to the inlined instance. const FunctionSamples * -SampleProfileLoader::findCalleeFunctionSamples(const CallInst &Inst) const { +SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const { const DILocation *DIL = Inst.getDebugLoc(); if (!DIL) { return nullptr; @@ -611,7 +609,6 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { return FS; } - /// \brief Iteratively inline hot callsites of a function. /// /// Iteratively traverse all callsites of the function \p F, and find if @@ -627,22 +624,36 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { bool SampleProfileLoader::inlineHotFunctions(Function &F) { bool Changed = false; LLVMContext &Ctx = F.getContext(); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&]( + Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); }; while (true) { bool LocalChanged = false; - SmallVector<CallInst *, 10> CIS; + SmallVector<Instruction *, 10> CIS; for (auto &BB : F) { + bool Hot = false; + SmallVector<Instruction *, 10> Candidates; for (auto &I : BB.getInstList()) { - CallInst *CI = dyn_cast<CallInst>(&I); - if (CI && callsiteIsHot(Samples, findCalleeFunctionSamples(*CI))) - CIS.push_back(CI); + const FunctionSamples *FS = nullptr; + if ((isa<CallInst>(I) || isa<InvokeInst>(I)) && + (FS = findCalleeFunctionSamples(I))) { + Candidates.push_back(&I); + if (callsiteIsHot(Samples, FS)) + Hot = true; + } + } + if (Hot) { + CIS.insert(CIS.begin(), Candidates.begin(), Candidates.end()); } } - for (auto CI : CIS) { - InlineFunctionInfo IFI(nullptr, ACT); - Function *CalledFunction = CI->getCalledFunction(); - DebugLoc DLoc = CI->getDebugLoc(); - uint64_t NumSamples = findCalleeFunctionSamples(*CI)->getTotalSamples(); - if (InlineFunction(CI, IFI)) { + for (auto I : CIS) { + InlineFunctionInfo IFI(nullptr, ACT ? 
&GetAssumptionCache : nullptr); + CallSite CS(I); + Function *CalledFunction = CS.getCalledFunction(); + if (!CalledFunction || !CalledFunction->getSubprogram()) + continue; + DebugLoc DLoc = I->getDebugLoc(); + uint64_t NumSamples = findCalleeFunctionSamples(*I)->getTotalSamples(); + if (InlineFunction(CS, IFI)) { LocalChanged = true; emitOptimizationRemark(Ctx, DEBUG_TYPE, F, DLoc, Twine("inlined hot callee '") + @@ -693,6 +704,10 @@ void SampleProfileLoader::findEquivalencesFor( bool IsInSameLoop = LI->getLoopFor(BB1) == LI->getLoopFor(BB2); if (BB1 != BB2 && IsDomParent && IsInSameLoop) { EquivalenceClass[BB2] = EC; + // If BB2 is visited, then the entire EC should be marked as visited. + if (VisitedBlocks.count(BB2)) { + VisitedBlocks.insert(EC); + } // If BB2 is heavier than BB1, make BB2 have the same weight // as BB1. @@ -705,7 +720,11 @@ void SampleProfileLoader::findEquivalencesFor( Weight = std::max(Weight, BlockWeights[BB2]); } } - BlockWeights[EC] = Weight; + if (EC == &EC->getParent()->getEntryBlock()) { + BlockWeights[EC] = Samples->getHeadSamples() + 1; + } else { + BlockWeights[EC] = Weight; + } } /// \brief Find equivalence classes. @@ -796,9 +815,12 @@ uint64_t SampleProfileLoader::visitEdge(Edge E, unsigned *NumUnknownEdges, /// count of the basic block, if needed. /// /// \param F Function to process. +/// \param UpdateBlockCount Whether we should update basic block counts that +/// has already been annotated. /// /// \returns True if new weights were assigned to edges or blocks. -bool SampleProfileLoader::propagateThroughEdges(Function &F) { +bool SampleProfileLoader::propagateThroughEdges(Function &F, + bool UpdateBlockCount) { bool Changed = false; DEBUG(dbgs() << "\nPropagation through edges\n"); for (const auto &BI : F) { @@ -890,11 +912,35 @@ bool SampleProfileLoader::propagateThroughEdges(Function &F) { EdgeWeights[UnknownEdge] = BBWeight - TotalWeight; else EdgeWeights[UnknownEdge] = 0; + const BasicBlock *OtherEC; + if (i == 0) + OtherEC = EquivalenceClass[UnknownEdge.first]; + else + OtherEC = EquivalenceClass[UnknownEdge.second]; + // Edge weights should never exceed the BB weights it connects. + if (VisitedBlocks.count(OtherEC) && + EdgeWeights[UnknownEdge] > BlockWeights[OtherEC]) + EdgeWeights[UnknownEdge] = BlockWeights[OtherEC]; VisitedEdges.insert(UnknownEdge); Changed = true; DEBUG(dbgs() << "Set weight for edge: "; printEdgeWeight(dbgs(), UnknownEdge)); } + } else if (VisitedBlocks.count(EC) && BlockWeights[EC] == 0) { + // If a block Weights 0, all its in/out edges should weight 0. + if (i == 0) { + for (auto *Pred : Predecessors[BB]) { + Edge E = std::make_pair(Pred, BB); + EdgeWeights[E] = 0; + VisitedEdges.insert(E); + } + } else { + for (auto *Succ : Successors[BB]) { + Edge E = std::make_pair(BB, Succ); + EdgeWeights[E] = 0; + VisitedEdges.insert(E); + } + } } else if (SelfReferentialEdge.first && VisitedBlocks.count(EC)) { uint64_t &BBWeight = BlockWeights[BB]; // We have a self-referential edge and the weight of BB is known. @@ -907,6 +953,11 @@ bool SampleProfileLoader::propagateThroughEdges(Function &F) { DEBUG(dbgs() << "Set self-referential edge weight to: "; printEdgeWeight(dbgs(), SelfReferentialEdge)); } + if (UpdateBlockCount && !VisitedBlocks.count(EC) && TotalWeight > 0) { + BlockWeights[EC] = TotalWeight; + VisitedBlocks.insert(EC); + Changed = true; + } } } @@ -966,7 +1017,21 @@ void SampleProfileLoader::propagateWeights(Function &F) { // Add an entry count to the function using the samples gathered // at the function entry. 
- F.setEntryCount(Samples->getHeadSamples()); + F.setEntryCount(Samples->getHeadSamples() + 1); + + // If BB weight is larger than its corresponding loop's header BB weight, + // use the BB weight to replace the loop header BB weight. + for (auto &BI : F) { + BasicBlock *BB = &BI; + Loop *L = LI->getLoopFor(BB); + if (!L) { + continue; + } + BasicBlock *Header = L->getHeader(); + if (Header && BlockWeights[BB] > BlockWeights[Header]) { + BlockWeights[Header] = BlockWeights[BB]; + } + } // Before propagation starts, build, for each block, a list of // unique predecessors and successors. This is necessary to handle @@ -977,7 +1042,23 @@ void SampleProfileLoader::propagateWeights(Function &F) { // Propagate until we converge or we go past the iteration limit. while (Changed && I++ < SampleProfileMaxPropagateIterations) { - Changed = propagateThroughEdges(F); + Changed = propagateThroughEdges(F, false); + } + + // The first propagation propagates BB counts from annotated BBs to unknown + // BBs. The 2nd propagation pass resets edges weights, and use all BB weights + // to propagate edge weights. + VisitedEdges.clear(); + Changed = true; + while (Changed && I++ < SampleProfileMaxPropagateIterations) { + Changed = propagateThroughEdges(F, false); + } + + // The 3rd propagation pass allows adjust annotated BB weights that are + // obviously wrong. + Changed = true; + while (Changed && I++ < SampleProfileMaxPropagateIterations) { + Changed = propagateThroughEdges(F, true); } // Generate MD_prof metadata for every branch instruction using the @@ -994,7 +1075,7 @@ void SampleProfileLoader::propagateWeights(Function &F) { if (!dyn_cast<IntrinsicInst>(&I)) { SmallVector<uint32_t, 1> Weights; Weights.push_back(BlockWeights[BB]); - CI->setMetadata(LLVMContext::MD_prof, + CI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); } } @@ -1023,7 +1104,9 @@ void SampleProfileLoader::propagateWeights(Function &F) { DEBUG(dbgs() << " (saturated due to uint32_t overflow)"); Weight = std::numeric_limits<uint32_t>::max(); } - Weights.push_back(static_cast<uint32_t>(Weight)); + // Weight is added by one to avoid propagation errors introduced by + // 0 weights. + Weights.push_back(static_cast<uint32_t>(Weight + 1)); if (Weight != 0) { if (Weight > MaxWeight) { MaxWeight = Weight; @@ -1192,10 +1275,10 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { char SampleProfileLoaderLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(SampleProfileLoaderLegacyPass, "sample-profile", - "Sample Profile loader", false, false) + "Sample Profile loader", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(SampleProfileLoaderLegacyPass, "sample-profile", - "Sample Profile loader", false, false) + "Sample Profile loader", false, false) bool SampleProfileLoader::doInitialization(Module &M) { auto &Ctx = M.getContext(); @@ -1232,12 +1315,13 @@ bool SampleProfileLoader::runOnModule(Module &M) { clearFunctionData(); retval |= runOnFunction(F); } - M.setProfileSummary(Reader->getSummary().getMD(M.getContext())); + if (M.getProfileSummary() == nullptr) + M.setProfileSummary(Reader->getSummary().getMD(M.getContext())); return retval; } bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) { - // FIXME: pass in AssumptionCache correctly for the new pass manager. + // FIXME: pass in AssumptionCache correctly for the new pass manager. 
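// ---- Sketch: one flow-conservation step of edge propagation --------------
// propagateThroughEdges() above repeatedly applies the rule that a block's
// weight equals the sum of its incoming (and of its outgoing) edge weights,
// so a single unknown edge around a block with a known weight can be solved
// for directly, clamping at zero because sampled weights are noisy.  The
// standalone helper below illustrates just that one step; the container and
// function names are invented for the example.

#include <cstdint>
#include <optional>
#include <vector>

// Given a known block weight and its incident edges on one side (known
// weights or unknown), fill in the edge weight when exactly one is unknown.
bool solveSingleUnknownEdge(uint64_t BlockWeight,
                            std::vector<std::optional<uint64_t>> &Edges) {
  uint64_t KnownSum = 0;
  std::optional<uint64_t> *Unknown = nullptr;
  unsigned NumUnknown = 0;
  for (auto &E : Edges) {
    if (E) {
      KnownSum += *E;
    } else {
      Unknown = &E;
      ++NumUnknown;
    }
  }
  if (NumUnknown != 1)
    return false; // either fully solved already or still under-determined
  // Clamp at zero, mirroring the BBWeight >= TotalWeight check above.
  *Unknown = BlockWeight >= KnownSum ? BlockWeight - KnownSum : 0;
  return true;
}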
SampleLoader.setACT(&getAnalysis<AssumptionCacheTracker>()); return SampleLoader.runOnModule(M); } @@ -1251,7 +1335,7 @@ bool SampleProfileLoader::runOnFunction(Function &F) { } PreservedAnalyses SampleProfileLoaderPass::run(Module &M, - AnalysisManager<Module> &AM) { + ModuleAnalysisManager &AM) { SampleProfileLoader SampleLoader(SampleProfileFile); diff --git a/contrib/llvm/lib/Transforms/IPO/StripSymbols.cpp b/contrib/llvm/lib/Transforms/IPO/StripSymbols.cpp index fd25036..8f6f161 100644 --- a/contrib/llvm/lib/Transforms/IPO/StripSymbols.cpp +++ b/contrib/llvm/lib/Transforms/IPO/StripSymbols.cpp @@ -219,7 +219,8 @@ static bool StripSymbolNames(Module &M, bool PreserveDbgInfo) { if (I.hasLocalLinkage() && llvmUsedValues.count(&I) == 0) if (!PreserveDbgInfo || !I.getName().startswith("llvm.dbg")) I.setName(""); // Internal symbols can't participate in linkage - StripSymtab(I.getValueSymbolTable(), PreserveDbgInfo); + if (auto *Symtab = I.getValueSymbolTable()) + StripSymtab(*Symtab, PreserveDbgInfo); } // Remove all names from types. @@ -312,26 +313,29 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { // replace the current list of potentially dead global variables/functions // with the live list. SmallVector<Metadata *, 64> LiveGlobalVariables; - SmallVector<Metadata *, 64> LiveSubprograms; - DenseSet<const MDNode *> VisitedSet; - - std::set<DISubprogram *> LiveSPs; - for (Function &F : M) { - if (DISubprogram *SP = F.getSubprogram()) - LiveSPs.insert(SP); + DenseSet<DIGlobalVariableExpression *> VisitedSet; + + std::set<DIGlobalVariableExpression *> LiveGVs; + for (GlobalVariable &GV : M.globals()) { + SmallVector<DIGlobalVariableExpression *, 1> GVEs; + GV.getDebugInfo(GVEs); + for (auto *GVE : GVEs) + LiveGVs.insert(GVE); } for (DICompileUnit *DIC : F.compile_units()) { // Create our live global variable list. bool GlobalVariableChange = false; - for (DIGlobalVariable *DIG : DIC->getGlobalVariables()) { + for (auto *DIG : DIC->getGlobalVariables()) { + if (DIG->getExpression() && DIG->getExpression()->isConstant()) + LiveGVs.insert(DIG); + // Make sure we only visit each global variable only once. if (!VisitedSet.insert(DIG).second) continue; - // If the global variable referenced by DIG is not null, the global - // variable is live. - if (DIG->getVariable()) + // If a global variable references DIG, the global variable is live. + if (LiveGVs.count(DIG)) LiveGlobalVariables.push_back(DIG); else GlobalVariableChange = true; @@ -345,7 +349,6 @@ bool StripDeadDebugInfo::runOnModule(Module &M) { } // Reset lists for the next iteration. - LiveSubprograms.clear(); LiveGlobalVariables.clear(); } diff --git a/contrib/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/contrib/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp new file mode 100644 index 0000000..3680cfc --- /dev/null +++ b/contrib/llvm/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -0,0 +1,344 @@ +//===- ThinLTOBitcodeWriter.cpp - Bitcode writing pass for ThinLTO --------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass prepares a module containing type metadata for ThinLTO by splitting +// it into regular and thin LTO parts if possible, and writing both parts to +// a multi-module bitcode file. Modules that do not contain type metadata are +// written unmodified as a single module. 
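For context, the legacy pass added below is normally scheduled through a PassManager. A minimal driver could look like the following sketch; it assumes the LLVM-4.0-era raw_fd_ostream constructor and sys::fs::F_None flag, and that createWriteThinLTOBitcodePass is declared in llvm/Transforms/IPO.h (in a real tool the module comes from a frontend or bitcode reader, and the pass registry must be initialized so the required ModuleSummaryIndexWrapperPass analysis can be scheduled).

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"

int main() {
  llvm::LLVMContext Ctx;
  llvm::Module M("example", Ctx); // placeholder; normally read from bitcode
  std::error_code EC;
  llvm::raw_fd_ostream OS("example.thinlto.bc", EC, llvm::sys::fs::F_None);
  if (EC)
    return 1;
  llvm::legacy::PassManager PM;
  PM.add(llvm::createWriteThinLTOBitcodePass(OS)); // pass defined in this file
  PM.run(M);
  return 0;
}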
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO.h" +#include "llvm/Analysis/ModuleSummaryAnalysis.h" +#include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/ScopedPrinter.h" +#include "llvm/Transforms/Utils/Cloning.h" +using namespace llvm; + +namespace { + +// Produce a unique identifier for this module by taking the MD5 sum of the +// names of the module's strong external symbols. This identifier is +// normally guaranteed to be unique, or the program would fail to link due to +// multiply defined symbols. +// +// If the module has no strong external symbols (such a module may still have a +// semantic effect if it performs global initialization), we cannot produce a +// unique identifier for this module, so we return the empty string, which +// causes the entire module to be written as a regular LTO module. +std::string getModuleId(Module *M) { + MD5 Md5; + bool ExportsSymbols = false; + auto AddGlobal = [&](GlobalValue &GV) { + if (GV.isDeclaration() || GV.getName().startswith("llvm.") || + !GV.hasExternalLinkage()) + return; + ExportsSymbols = true; + Md5.update(GV.getName()); + Md5.update(ArrayRef<uint8_t>{0}); + }; + + for (auto &F : *M) + AddGlobal(F); + for (auto &GV : M->globals()) + AddGlobal(GV); + for (auto &GA : M->aliases()) + AddGlobal(GA); + for (auto &IF : M->ifuncs()) + AddGlobal(IF); + + if (!ExportsSymbols) + return ""; + + MD5::MD5Result R; + Md5.final(R); + + SmallString<32> Str; + MD5::stringifyResult(R, Str); + return ("$" + Str).str(); +} + +// Promote each local-linkage entity defined by ExportM and used by ImportM by +// changing visibility and appending the given ModuleId. +void promoteInternals(Module &ExportM, Module &ImportM, StringRef ModuleId) { + auto PromoteInternal = [&](GlobalValue &ExportGV) { + if (!ExportGV.hasLocalLinkage()) + return; + + GlobalValue *ImportGV = ImportM.getNamedValue(ExportGV.getName()); + if (!ImportGV || ImportGV->use_empty()) + return; + + std::string NewName = (ExportGV.getName() + ModuleId).str(); + + ExportGV.setName(NewName); + ExportGV.setLinkage(GlobalValue::ExternalLinkage); + ExportGV.setVisibility(GlobalValue::HiddenVisibility); + + ImportGV->setName(NewName); + ImportGV->setVisibility(GlobalValue::HiddenVisibility); + }; + + for (auto &F : ExportM) + PromoteInternal(F); + for (auto &GV : ExportM.globals()) + PromoteInternal(GV); + for (auto &GA : ExportM.aliases()) + PromoteInternal(GA); + for (auto &IF : ExportM.ifuncs()) + PromoteInternal(IF); +} + +// Promote all internal (i.e. distinct) type ids used by the module by replacing +// them with external type ids formed using the module id. +// +// Note that this needs to be done before we clone the module because each clone +// will receive its own set of distinct metadata nodes. 
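The module-id scheme used by getModuleId above can be reproduced in isolation: hash the exported symbol names with a NUL separator (so adjacent names cannot run together) and prefix the hex digest with '$'. A sketch using the same llvm::MD5 helpers that appear in the hunk; an empty name list corresponds to the "no strong external symbols" case where the real code returns the empty string instead.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MD5.h"
#include <cstdint>
#include <string>

// Derive a module identifier from the names of the exported symbols.
std::string makeModuleId(llvm::ArrayRef<llvm::StringRef> ExportedNames) {
  llvm::MD5 Md5;
  for (llvm::StringRef Name : ExportedNames) {
    Md5.update(Name);
    Md5.update(llvm::ArrayRef<uint8_t>{0}); // separator between names
  }
  llvm::MD5::MD5Result R;
  Md5.final(R);
  llvm::SmallString<32> Str;
  llvm::MD5::stringifyResult(R, Str);
  return ("$" + Str).str();
}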
+void promoteTypeIds(Module &M, StringRef ModuleId) { + DenseMap<Metadata *, Metadata *> LocalToGlobal; + auto ExternalizeTypeId = [&](CallInst *CI, unsigned ArgNo) { + Metadata *MD = + cast<MetadataAsValue>(CI->getArgOperand(ArgNo))->getMetadata(); + + if (isa<MDNode>(MD) && cast<MDNode>(MD)->isDistinct()) { + Metadata *&GlobalMD = LocalToGlobal[MD]; + if (!GlobalMD) { + std::string NewName = + (to_string(LocalToGlobal.size()) + ModuleId).str(); + GlobalMD = MDString::get(M.getContext(), NewName); + } + + CI->setArgOperand(ArgNo, + MetadataAsValue::get(M.getContext(), GlobalMD)); + } + }; + + if (Function *TypeTestFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_test))) { + for (const Use &U : TypeTestFunc->uses()) { + auto CI = cast<CallInst>(U.getUser()); + ExternalizeTypeId(CI, 1); + } + } + + if (Function *TypeCheckedLoadFunc = + M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load))) { + for (const Use &U : TypeCheckedLoadFunc->uses()) { + auto CI = cast<CallInst>(U.getUser()); + ExternalizeTypeId(CI, 2); + } + } + + for (GlobalObject &GO : M.global_objects()) { + SmallVector<MDNode *, 1> MDs; + GO.getMetadata(LLVMContext::MD_type, MDs); + + GO.eraseMetadata(LLVMContext::MD_type); + for (auto MD : MDs) { + auto I = LocalToGlobal.find(MD->getOperand(1)); + if (I == LocalToGlobal.end()) { + GO.addMetadata(LLVMContext::MD_type, *MD); + continue; + } + GO.addMetadata( + LLVMContext::MD_type, + *MDNode::get(M.getContext(), + ArrayRef<Metadata *>{MD->getOperand(0), I->second})); + } + } +} + +// Drop unused globals, and drop type information from function declarations. +// FIXME: If we made functions typeless then there would be no need to do this. +void simplifyExternals(Module &M) { + FunctionType *EmptyFT = + FunctionType::get(Type::getVoidTy(M.getContext()), false); + + for (auto I = M.begin(), E = M.end(); I != E;) { + Function &F = *I++; + if (F.isDeclaration() && F.use_empty()) { + F.eraseFromParent(); + continue; + } + + if (!F.isDeclaration() || F.getFunctionType() == EmptyFT) + continue; + + Function *NewF = + Function::Create(EmptyFT, GlobalValue::ExternalLinkage, "", &M); + NewF->setVisibility(F.getVisibility()); + NewF->takeName(&F); + F.replaceAllUsesWith(ConstantExpr::getBitCast(NewF, F.getType())); + F.eraseFromParent(); + } + + for (auto I = M.global_begin(), E = M.global_end(); I != E;) { + GlobalVariable &GV = *I++; + if (GV.isDeclaration() && GV.use_empty()) { + GV.eraseFromParent(); + continue; + } + } +} + +void filterModule( + Module *M, std::function<bool(const GlobalValue *)> ShouldKeepDefinition) { + for (Function &F : *M) { + if (ShouldKeepDefinition(&F)) + continue; + + F.deleteBody(); + F.clearMetadata(); + } + + for (GlobalVariable &GV : M->globals()) { + if (ShouldKeepDefinition(&GV)) + continue; + + GV.setInitializer(nullptr); + GV.setLinkage(GlobalValue::ExternalLinkage); + GV.clearMetadata(); + } + + for (Module::alias_iterator I = M->alias_begin(), E = M->alias_end(); + I != E;) { + GlobalAlias *GA = &*I++; + if (ShouldKeepDefinition(GA)) + continue; + + GlobalObject *GO; + if (I->getValueType()->isFunctionTy()) + GO = Function::Create(cast<FunctionType>(GA->getValueType()), + GlobalValue::ExternalLinkage, "", M); + else + GO = new GlobalVariable( + *M, GA->getValueType(), false, GlobalValue::ExternalLinkage, + (Constant *)nullptr, "", (GlobalVariable *)nullptr, + GA->getThreadLocalMode(), GA->getType()->getAddressSpace()); + GO->takeName(GA); + GA->replaceAllUsesWith(GO); + GA->eraseFromParent(); + } +} + +// If it's possible to split M 
into regular and thin LTO parts, do so and write +// a multi-module bitcode file with the two parts to OS. Otherwise, write only a +// regular LTO bitcode file to OS. +void splitAndWriteThinLTOBitcode(raw_ostream &OS, Module &M) { + std::string ModuleId = getModuleId(&M); + if (ModuleId.empty()) { + // We couldn't generate a module ID for this module, just write it out as a + // regular LTO module. + WriteBitcodeToFile(&M, OS); + return; + } + + promoteTypeIds(M, ModuleId); + + auto IsInMergedM = [&](const GlobalValue *GV) { + auto *GVar = dyn_cast<GlobalVariable>(GV->getBaseObject()); + if (!GVar) + return false; + + SmallVector<MDNode *, 1> MDs; + GVar->getMetadata(LLVMContext::MD_type, MDs); + return !MDs.empty(); + }; + + ValueToValueMapTy VMap; + std::unique_ptr<Module> MergedM(CloneModule(&M, VMap, IsInMergedM)); + + filterModule(&M, [&](const GlobalValue *GV) { return !IsInMergedM(GV); }); + + promoteInternals(*MergedM, M, ModuleId); + promoteInternals(M, *MergedM, ModuleId); + + simplifyExternals(*MergedM); + + SmallVector<char, 0> Buffer; + BitcodeWriter W(Buffer); + + // FIXME: Try to re-use BSI and PFI from the original module here. + ModuleSummaryIndex Index = buildModuleSummaryIndex(M, nullptr, nullptr); + W.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index, + /*GenerateHash=*/true); + + W.writeModule(MergedM.get()); + + OS << Buffer; +} + +// Returns whether this module needs to be split because it uses type metadata. +bool requiresSplit(Module &M) { + SmallVector<MDNode *, 1> MDs; + for (auto &GO : M.global_objects()) { + GO.getMetadata(LLVMContext::MD_type, MDs); + if (!MDs.empty()) + return true; + } + + return false; +} + +void writeThinLTOBitcode(raw_ostream &OS, Module &M, + const ModuleSummaryIndex *Index) { + // See if this module has any type metadata. If so, we need to split it. + if (requiresSplit(M)) + return splitAndWriteThinLTOBitcode(OS, M); + + // Otherwise we can just write it out as a regular module. 
+ WriteBitcodeToFile(&M, OS, /*ShouldPreserveUseListOrder=*/false, Index, + /*GenerateHash=*/true); +} + +class WriteThinLTOBitcode : public ModulePass { + raw_ostream &OS; // raw_ostream to print on + +public: + static char ID; // Pass identification, replacement for typeid + WriteThinLTOBitcode() : ModulePass(ID), OS(dbgs()) { + initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry()); + } + + explicit WriteThinLTOBitcode(raw_ostream &o) + : ModulePass(ID), OS(o) { + initializeWriteThinLTOBitcodePass(*PassRegistry::getPassRegistry()); + } + + StringRef getPassName() const override { return "ThinLTO Bitcode Writer"; } + + bool runOnModule(Module &M) override { + const ModuleSummaryIndex *Index = + &(getAnalysis<ModuleSummaryIndexWrapperPass>().getIndex()); + writeThinLTOBitcode(OS, M, Index); + return true; + } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + AU.addRequired<ModuleSummaryIndexWrapperPass>(); + } +}; +} // anonymous namespace + +char WriteThinLTOBitcode::ID = 0; +INITIALIZE_PASS_BEGIN(WriteThinLTOBitcode, "write-thinlto-bitcode", + "Write ThinLTO Bitcode", false, true) +INITIALIZE_PASS_DEPENDENCY(ModuleSummaryIndexWrapperPass) +INITIALIZE_PASS_END(WriteThinLTOBitcode, "write-thinlto-bitcode", + "Write ThinLTO Bitcode", false, true) + +ModulePass *llvm::createWriteThinLTOBitcodePass(raw_ostream &Str) { + return new WriteThinLTOBitcode(Str); +} diff --git a/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index 53eb4e2..844cc0f 100644 --- a/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/contrib/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -29,24 +29,43 @@ #include "llvm/Transforms/IPO/WholeProgramDevirt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/PassRegistry.h" +#include "llvm/PassSupport.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/Evaluator.h" -#include "llvm/Transforms/Utils/Local.h" - +#include <algorithm> +#include <cstddef> +#include <map> #include <set> +#include <string> using namespace llvm; using namespace wholeprogramdevirt; @@ -166,7 +185,7 @@ void wholeprogramdevirt::setAfterReturnValues( VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM) : Fn(Fn), TM(TM), - IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()) {} + IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()), WasDevirt(false) {} namespace { @@ -178,7 +197,7 @@ struct VTableSlot { uint64_t ByteOffset; }; -} +} // end anonymous namespace namespace 
llvm { @@ -201,7 +220,7 @@ template <> struct DenseMapInfo<VTableSlot> { } }; -} +} // end namespace llvm namespace { @@ -216,15 +235,18 @@ struct VirtualCallSite { // of that field for details. unsigned *NumUnsafeUses; - void emitRemark() { + void emitRemark(const Twine &OptName, const Twine &TargetName) { Function *F = CS.getCaller(); - emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, - CS.getInstruction()->getDebugLoc(), - "devirtualized call"); + emitOptimizationRemark( + F->getContext(), DEBUG_TYPE, *F, + CS.getInstruction()->getDebugLoc(), + OptName + ": devirtualized a call to " + TargetName); } - void replaceAndErase(Value *New) { - emitRemark(); + void replaceAndErase(const Twine &OptName, const Twine &TargetName, + bool RemarksEnabled, Value *New) { + if (RemarksEnabled) + emitRemark(OptName, TargetName); CS->replaceAllUsesWith(New); if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { BranchInst::Create(II->getNormalDest(), CS.getInstruction()); @@ -243,6 +265,8 @@ struct DevirtModule { PointerType *Int8PtrTy; IntegerType *Int32Ty; + bool RemarksEnabled; + MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots; // This map keeps track of the number of "unsafe" uses of a loaded function @@ -258,7 +282,10 @@ struct DevirtModule { DevirtModule(Module &M) : M(M), Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(Type::getInt8PtrTy(M.getContext())), - Int32Ty(Type::getInt32Ty(M.getContext())) {} + Int32Ty(Type::getInt32Ty(M.getContext())), + RemarksEnabled(areRemarksEnabled()) {} + + bool areRemarksEnabled(); void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc); void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc); @@ -266,20 +293,21 @@ struct DevirtModule { void buildTypeIdentifierMap( std::vector<VTableBits> &Bits, DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap); + Constant *getPointerAtOffset(Constant *I, uint64_t Offset); bool tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot, const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset); - bool trySingleImplDevirt(ArrayRef<VirtualCallTarget> TargetsForSlot, + bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot, MutableArrayRef<VirtualCallSite> CallSites); bool tryEvaluateFunctionsWithArgs( MutableArrayRef<VirtualCallTarget> TargetsForSlot, ArrayRef<ConstantInt *> Args); bool tryUniformRetValOpt(IntegerType *RetType, - ArrayRef<VirtualCallTarget> TargetsForSlot, + MutableArrayRef<VirtualCallTarget> TargetsForSlot, MutableArrayRef<VirtualCallSite> CallSites); bool tryUniqueRetValOpt(unsigned BitWidth, - ArrayRef<VirtualCallTarget> TargetsForSlot, + MutableArrayRef<VirtualCallTarget> TargetsForSlot, MutableArrayRef<VirtualCallSite> CallSites); bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot, ArrayRef<VirtualCallSite> CallSites); @@ -291,10 +319,12 @@ struct DevirtModule { struct WholeProgramDevirt : public ModulePass { static char ID; + WholeProgramDevirt() : ModulePass(ID) { initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry()); } - bool runOnModule(Module &M) { + + bool runOnModule(Module &M) override { if (skipModule(M)) return false; @@ -302,7 +332,7 @@ struct WholeProgramDevirt : public ModulePass { } }; -} // anonymous namespace +} // end anonymous namespace INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt", "Whole program devirtualization", false, false) @@ -353,6 +383,38 @@ void DevirtModule::buildTypeIdentifierMap( } } +Constant 
*DevirtModule::getPointerAtOffset(Constant *I, uint64_t Offset) { + if (I->getType()->isPointerTy()) { + if (Offset == 0) + return I; + return nullptr; + } + + const DataLayout &DL = M.getDataLayout(); + + if (auto *C = dyn_cast<ConstantStruct>(I)) { + const StructLayout *SL = DL.getStructLayout(C->getType()); + if (Offset >= SL->getSizeInBytes()) + return nullptr; + + unsigned Op = SL->getElementContainingOffset(Offset); + return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), + Offset - SL->getElementOffset(Op)); + } + if (auto *C = dyn_cast<ConstantArray>(I)) { + ArrayType *VTableTy = C->getType(); + uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType()); + + unsigned Op = Offset / ElemSize; + if (Op >= C->getNumOperands()) + return nullptr; + + return getPointerAtOffset(cast<Constant>(I->getOperand(Op)), + Offset % ElemSize); + } + return nullptr; +} + bool DevirtModule::tryFindVirtualCallTargets( std::vector<VirtualCallTarget> &TargetsForSlot, const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) { @@ -360,22 +422,12 @@ bool DevirtModule::tryFindVirtualCallTargets( if (!TM.Bits->GV->isConstant()) return false; - auto Init = dyn_cast<ConstantArray>(TM.Bits->GV->getInitializer()); - if (!Init) - return false; - ArrayType *VTableTy = Init->getType(); - - uint64_t ElemSize = - M.getDataLayout().getTypeAllocSize(VTableTy->getElementType()); - uint64_t GlobalSlotOffset = TM.Offset + ByteOffset; - if (GlobalSlotOffset % ElemSize != 0) - return false; - - unsigned Op = GlobalSlotOffset / ElemSize; - if (Op >= Init->getNumOperands()) + Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(), + TM.Offset + ByteOffset); + if (!Ptr) return false; - auto Fn = dyn_cast<Function>(Init->getOperand(Op)->stripPointerCasts()); + auto Fn = dyn_cast<Function>(Ptr->stripPointerCasts()); if (!Fn) return false; @@ -392,7 +444,7 @@ bool DevirtModule::tryFindVirtualCallTargets( } bool DevirtModule::trySingleImplDevirt( - ArrayRef<VirtualCallTarget> TargetsForSlot, + MutableArrayRef<VirtualCallTarget> TargetsForSlot, MutableArrayRef<VirtualCallSite> CallSites) { // See if the program contains a single implementation of this virtual // function. @@ -401,9 +453,12 @@ bool DevirtModule::trySingleImplDevirt( if (TheFn != Target.Fn) return false; + if (RemarksEnabled) + TargetsForSlot[0].WasDevirt = true; // If so, update each call site to call that implementation directly. for (auto &&VCallSite : CallSites) { - VCallSite.emitRemark(); + if (RemarksEnabled) + VCallSite.emitRemark("single-impl", TheFn->getName()); VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( TheFn, VCallSite.CS.getCalledValue()->getType())); // This use is no longer unsafe. @@ -441,7 +496,7 @@ bool DevirtModule::tryEvaluateFunctionsWithArgs( } bool DevirtModule::tryUniformRetValOpt( - IntegerType *RetType, ArrayRef<VirtualCallTarget> TargetsForSlot, + IntegerType *RetType, MutableArrayRef<VirtualCallTarget> TargetsForSlot, MutableArrayRef<VirtualCallSite> CallSites) { // Uniform return value optimization. If all functions return the same // constant, replace all calls with that constant. 
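getPointerAtOffset above recurses through nested ConstantStruct/ConstantArray initializers, using the struct layout for field offsets and the element allocation size for arrays, until the byte offset lands exactly on a pointer. The toy model below captures just that arithmetic without the LLVM types; field offsets are assumed to be sorted ascending, and unlike the real code it does not check the total struct size.

#include <cstdint>
#include <vector>

// A "constant" is a leaf pointer, a struct (explicit field offsets), or an
// array (fixed element size). The lookup walks down to the leaf that lives at
// the requested byte offset, or returns nullptr if no pointer sits there.
struct ToyConst {
  enum Kind { Pointer, Struct, Array } K;
  std::vector<ToyConst> Elems;       // Struct/Array members.
  std::vector<uint64_t> FieldOffset; // Struct only, one entry per member.
  uint64_t ElemSize;                 // Array only.
};

const ToyConst *pointerAtOffset(const ToyConst &C, uint64_t Offset) {
  switch (C.K) {
  case ToyConst::Pointer:
    return Offset == 0 ? &C : nullptr;
  case ToyConst::Struct:
    // Find the field containing Offset and recurse with the remainder.
    for (size_t I = C.Elems.size(); I-- > 0;)
      if (Offset >= C.FieldOffset[I])
        return pointerAtOffset(C.Elems[I], Offset - C.FieldOffset[I]);
    return nullptr;
  case ToyConst::Array: {
    uint64_t Idx = Offset / C.ElemSize;
    if (Idx >= C.Elems.size())
      return nullptr;
    return pointerAtOffset(C.Elems[Idx], Offset % C.ElemSize);
  }
  }
  return nullptr;
}

int main() {
  // A vtable-like array of three 8-byte pointer slots; offset 16 is slot 2.
  ToyConst Slot{ToyConst::Pointer, {}, {}, 0};
  ToyConst VTable{ToyConst::Array, {Slot, Slot, Slot}, {}, 8};
  return pointerAtOffset(VTable, 16) == &VTable.Elems[2] ? 0 : 1;
}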
@@ -452,16 +507,20 @@ bool DevirtModule::tryUniformRetValOpt( auto TheRetValConst = ConstantInt::get(RetType, TheRetVal); for (auto Call : CallSites) - Call.replaceAndErase(TheRetValConst); + Call.replaceAndErase("uniform-ret-val", TargetsForSlot[0].Fn->getName(), + RemarksEnabled, TheRetValConst); + if (RemarksEnabled) + for (auto &&Target : TargetsForSlot) + Target.WasDevirt = true; return true; } bool DevirtModule::tryUniqueRetValOpt( - unsigned BitWidth, ArrayRef<VirtualCallTarget> TargetsForSlot, + unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot, MutableArrayRef<VirtualCallSite> CallSites) { // IsOne controls whether we look for a 0 or a 1. auto tryUniqueRetValOptFor = [&](bool IsOne) { - const TypeMemberInfo *UniqueMember = 0; + const TypeMemberInfo *UniqueMember = nullptr; for (const VirtualCallTarget &Target : TargetsForSlot) { if (Target.RetVal == (IsOne ? 1 : 0)) { if (UniqueMember) @@ -481,8 +540,14 @@ bool DevirtModule::tryUniqueRetValOpt( OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset); Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable, OneAddr); - Call.replaceAndErase(Cmp); + Call.replaceAndErase("unique-ret-val", TargetsForSlot[0].Fn->getName(), + RemarksEnabled, Cmp); } + // Update devirtualization statistics for targets. + if (RemarksEnabled) + for (auto &&Target : TargetsForSlot) + Target.WasDevirt = true; + return true; }; @@ -590,6 +655,10 @@ bool DevirtModule::tryVirtualConstProp( setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte, OffsetBit); + if (RemarksEnabled) + for (auto &&Target : TargetsForSlot) + Target.WasDevirt = true; + // Rewrite each call to a load from OffsetByte/OffsetBit. for (auto Call : CSByConstantArg.second) { IRBuilder<> B(Call.CS.getInstruction()); @@ -599,11 +668,15 @@ bool DevirtModule::tryVirtualConstProp( Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); Value *BitsAndBit = B.CreateAnd(Bits, Bit); auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); - Call.replaceAndErase(IsBitSet); + Call.replaceAndErase("virtual-const-prop-1-bit", + TargetsForSlot[0].Fn->getName(), + RemarksEnabled, IsBitSet); } else { Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); Value *Val = B.CreateLoad(RetType, ValAddr); - Call.replaceAndErase(Val); + Call.replaceAndErase("virtual-const-prop", + TargetsForSlot[0].Fn->getName(), + RemarksEnabled, Val); } } } @@ -655,6 +728,15 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { B.GV->eraseFromParent(); } +bool DevirtModule::areRemarksEnabled() { + const auto &FL = M.getFunctionList(); + if (FL.empty()) + return false; + const Function &Fn = FL.front(); + auto DI = OptimizationRemark(DEBUG_TYPE, Fn, DebugLoc(), ""); + return DI.isEnabled(); +} + void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc) { // Find all virtual calls via a virtual table pointer %p under an assumption @@ -806,6 +888,7 @@ bool DevirtModule::run() { // For each (type, offset) pair: bool DidVirtualConstProp = false; + std::map<std::string, Function*> DevirtTargets; for (auto &S : CallSlots) { // Search each of the members of the type identifier for the virtual // function implementation at offset S.first.ByteOffset, and add to @@ -815,10 +898,26 @@ bool DevirtModule::run() { S.first.ByteOffset)) continue; - if (trySingleImplDevirt(TargetsForSlot, S.second)) - continue; + if (!trySingleImplDevirt(TargetsForSlot, S.second) && + tryVirtualConstProp(TargetsForSlot, S.second)) + 
DidVirtualConstProp = true; - DidVirtualConstProp |= tryVirtualConstProp(TargetsForSlot, S.second); + // Collect functions devirtualized at least for one call site for stats. + if (RemarksEnabled) + for (const auto &T : TargetsForSlot) + if (T.WasDevirt) + DevirtTargets[T.Fn->getName()] = T.Fn; + } + + if (RemarksEnabled) { + // Generate remarks for each devirtualized function. + for (const auto &DT : DevirtTargets) { + Function *F = DT.second; + DISubprogram *SP = F->getSubprogram(); + DebugLoc DL = SP ? DebugLoc::get(SP->getScopeLine(), 0, SP) : DebugLoc(); + emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, DL, + Twine("devirtualized ") + F->getName()); + } } // If we were able to eliminate all unsafe uses for a type checked load, diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 221a220..2d34c1c 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1035,7 +1035,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) + I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A*B)+(A*C) -> A*(B+C) etc @@ -1047,6 +1047,28 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // X + (signbit) --> X ^ signbit if (Val->isSignBit()) return BinaryOperator::CreateXor(LHS, RHS); + + // Is this add the last step in a convoluted sext? + Value *X; + const APInt *C; + if (match(LHS, m_ZExt(m_Xor(m_Value(X), m_APInt(C)))) && + C->isMinSignedValue() && + C->sext(LHS->getType()->getScalarSizeInBits()) == *Val) { + // add(zext(xor i16 X, -32768), -32768) --> sext X + return CastInst::Create(Instruction::SExt, X, LHS->getType()); + } + + if (Val->isNegative() && + match(LHS, m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C)))) && + Val->sge(-C->sext(Val->getBitWidth()))) { + // (add (zext (add nuw X, C)), Val) -> (zext (add nuw X, C+Val)) + return CastInst::Create( + Instruction::ZExt, + Builder->CreateNUWAdd( + X, Constant::getIntegerValue(X->getType(), + *C + Val->trunc(C->getBitWidth()))), + I.getType()); + } } // FIXME: Use the match above instead of dyn_cast to allow these transforms @@ -1144,7 +1166,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return replaceInstUsesWith(I, V); // A+B --> A|B iff A and B have no bits set in common. - if (haveNoCommonBitsSet(LHS, RHS, DL, AC, &I, DT)) + if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); if (Constant *CRHS = dyn_cast<Constant>(RHS)) { @@ -1216,15 +1238,16 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (SExtInst *LHSConv = dyn_cast<SExtInst>(LHS)) { // (add (sext x), cst) --> (sext (add x, cst')) if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { - Constant *CI = - ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); - if (LHSConv->hasOneUse() && - ConstantExpr::getSExt(CI, I.getType()) == RHSC && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { - // Insert the new, smaller add. 
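The new "convoluted sext" fold above, add (zext (xor i16 X, 0x8000)), -32768 --> sext X, can be checked without any LLVM machinery: flipping the sign bit, zero-extending, and re-biasing by -2^15 is exactly a sign extension. A standalone exhaustive check over all 16-bit values (the narrowing cast to int16_t assumes the usual two's-complement behaviour):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t V = 0; V <= 0xFFFF; ++V) {
    uint16_t X = static_cast<uint16_t>(V);
    // Original pattern: zext(X ^ 0x8000) + (-32768).
    int32_t ZextXorAdd =
        static_cast<int32_t>(static_cast<uint16_t>(X ^ 0x8000)) - 32768;
    // Folded form: sext i16 X to i32.
    int32_t SextX = static_cast<int32_t>(static_cast<int16_t>(X));
    assert(ZextXorAdd == SextX);
  }
  return 0;
}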
- Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), - CI, "addconv"); - return new SExtInst(NewAdd, I.getType()); + if (LHSConv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); + if (ConstantExpr::getSExt(CI, I.getType()) == RHSC && + WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { + // Insert the new, smaller add. + Value *NewAdd = + Builder->CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); + return new SExtInst(NewAdd, I.getType()); + } } } @@ -1246,6 +1269,44 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } + // Check for (add (zext x), y), see if we can merge this into an + // integer add followed by a zext. + if (auto *LHSConv = dyn_cast<ZExtInst>(LHS)) { + // (add (zext x), cst) --> (zext (add x, cst')) + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { + if (LHSConv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); + if (ConstantExpr::getZExt(CI, I.getType()) == RHSC && + computeOverflowForUnsignedAdd(LHSConv->getOperand(0), CI, &I) == + OverflowResult::NeverOverflows) { + // Insert the new, smaller add. + Value *NewAdd = + Builder->CreateNUWAdd(LHSConv->getOperand(0), CI, "addconv"); + return new ZExtInst(NewAdd, I.getType()); + } + } + } + + // (add (zext x), (zext y)) --> (zext (add int x, y)) + if (auto *RHSConv = dyn_cast<ZExtInst>(RHS)) { + // Only do this if x/y have the same type, if at last one of them has a + // single use (so we don't increase the number of zexts), and if the + // integer add will not overflow. + if (LHSConv->getOperand(0)->getType() == + RHSConv->getOperand(0)->getType() && + (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && + computeOverflowForUnsignedAdd(LHSConv->getOperand(0), + RHSConv->getOperand(0), + &I) == OverflowResult::NeverOverflows) { + // Insert the new integer add. 
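The zext-merging folds in this hunk rely on one fact: when the narrow add cannot wrap, adding the zero-extended operands is the same as zero-extending the narrow add. A standalone check over i8, independent of the InstCombine code, illustrating the condition under which the fold fires:

#include <cassert>
#include <cstdint>

// add (zext i8 X), (zext i8 Y)  ==  zext (add nuw i8 X, Y)
// whenever X + Y fits in 8 bits (the NeverOverflows case).
int main() {
  for (uint32_t X = 0; X <= 0xFF; ++X)
    for (uint32_t Y = 0; Y <= 0xFF; ++Y) {
      if (X + Y > 0xFF)
        continue; // the fold only fires when the narrow add cannot overflow
      uint32_t Wide = X + Y;                         // add (zext X), (zext Y)
      uint32_t Narrow = static_cast<uint8_t>(X + Y); // zext (add nuw X, Y)
      assert(Wide == Narrow);
    }
  return 0;
}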
+ Value *NewAdd = Builder->CreateNUWAdd( + LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); + return new ZExtInst(NewAdd, I.getType()); + } + } + } + // (add (xor A, B) (and A, B)) --> (or A, B) { Value *A = nullptr, *B = nullptr; @@ -1307,18 +1368,12 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = - SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, TLI, DT, AC)) + SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); - if (isa<Constant>(RHS)) { - if (isa<PHINode>(LHS)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; - - if (SelectInst *SI = dyn_cast<SelectInst>(LHS)) - if (Instruction *NV = FoldOpIntoSelect(I, SI)) - return NV; - } + if (isa<Constant>(RHS)) + if (Instruction *FoldedFAdd = foldOpWithConstantIntoOperand(I)) + return FoldedFAdd; // -A + B --> B - A // -A + -B --> -(A + B) @@ -1483,7 +1538,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) + I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A*B)-(A*C) -> A*(B-C) etc @@ -1544,34 +1599,35 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return CastInst::CreateZExtOrBitCast(X, Op1->getType()); } - if (ConstantInt *C = dyn_cast<ConstantInt>(Op0)) { + const APInt *Op0C; + if (match(Op0, m_APInt(Op0C))) { + unsigned BitWidth = I.getType()->getScalarSizeInBits(); + // -(X >>u 31) -> (X >>s 31) // -(X >>s 31) -> (X >>u 31) - if (C->isZero()) { + if (*Op0C == 0) { Value *X; - ConstantInt *CI; - if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) && - // Verify we are shifting out everything but the sign bit. - CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1) - return BinaryOperator::CreateAShr(X, CI); - - if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) && - // Verify we are shifting out everything but the sign bit. - CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1) - return BinaryOperator::CreateLShr(X, CI); + const APInt *ShAmt; + if (match(Op1, m_LShr(m_Value(X), m_APInt(ShAmt))) && + *ShAmt == BitWidth - 1) { + Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); + return BinaryOperator::CreateAShr(X, ShAmtOp); + } + if (match(Op1, m_AShr(m_Value(X), m_APInt(ShAmt))) && + *ShAmt == BitWidth - 1) { + Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); + return BinaryOperator::CreateLShr(X, ShAmtOp); + } } // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. - APInt IntVal = C->getValue(); - if ((IntVal + 1).isPowerOf2()) { - unsigned BitWidth = I.getType()->getScalarSizeInBits(); + if ((*Op0C + 1).isPowerOf2()) { APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); computeKnownBits(&I, KnownZero, KnownOne, 0, &I); - if ((IntVal | KnownZero).isAllOnesValue()) { - return BinaryOperator::CreateXor(Op1, C); - } + if ((*Op0C | KnownZero).isAllOnesValue()) + return BinaryOperator::CreateXor(Op1, Op0); } } @@ -1632,6 +1688,17 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Value *XNeg = dyn_castNegVal(X)) return BinaryOperator::CreateShl(XNeg, Y); + // Subtracting -1/0 is the same as adding 1/0: + // sub [nsw] Op0, sext(bool Y) -> add [nsw] Op0, zext(bool Y) + // 'nuw' is dropped in favor of the canonical form. 
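The negation-of-shift folds handled in the visitSub hunk above rest on a pair of 32-bit identities that are easy to confirm in isolation (the signed right shift of a negative value is assumed to be an arithmetic shift, as on all LLVM targets):

#include <cassert>
#include <cstdint>

// 0 - (X >>u 31)  ==  X >>s 31   (all-ones iff the sign bit is set)
// 0 - (X >>s 31)  ==  X >>u 31   (one iff the sign bit is set)
int main() {
  const uint32_t Samples[] = {0u,          1u,          0x7FFFFFFFu,
                              0x80000000u, 0xFFFFFFFFu, 12345u,
                              0xDEADBEEFu};
  for (uint32_t X : Samples) {
    uint32_t LShr = X >> 31; // logical shift
    uint32_t AShr =
        static_cast<uint32_t>(static_cast<int32_t>(X) >> 31); // arithmetic
    assert(static_cast<uint32_t>(0u - LShr) == AShr);
    assert(static_cast<uint32_t>(0u - AShr) == LShr);
  }
  return 0;
}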
+ if (match(Op1, m_SExt(m_Value(Y))) && + Y->getType()->getScalarSizeInBits() == 1) { + Value *Zext = Builder->CreateZExt(Y, I.getType()); + BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Zext); + Add->setHasNoSignedWrap(I.hasNoSignedWrap()); + return Add; + } + // X - A*-B -> X + A*B // X - -A*B -> X + A*B Value *A, *B; @@ -1682,7 +1749,7 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = - SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) + SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // fsub nsz 0, X ==> fsub nsz -0.0, X diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 1a6459b..da5384a 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -98,12 +98,11 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) { IntegerType *ITy = dyn_cast<IntegerType>(I.getType()); // Can't do vectors. - if (I.getType()->isVectorTy()) return nullptr; + if (I.getType()->isVectorTy()) + return nullptr; // Can only do bitwise ops. - unsigned Op = I.getOpcode(); - if (Op != Instruction::And && Op != Instruction::Or && - Op != Instruction::Xor) + if (!I.isBitwiseLogicOp()) return nullptr; Value *OldLHS = I.getOperand(0); @@ -132,14 +131,7 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) { Value *NewRHS = IsBswapRHS ? IntrRHS->getOperand(0) : Builder->getInt(ConstRHS->getValue().byteSwap()); - Value *BinOp = nullptr; - if (Op == Instruction::And) - BinOp = Builder->CreateAnd(NewLHS, NewRHS); - else if (Op == Instruction::Or) - BinOp = Builder->CreateOr(NewLHS, NewRHS); - else //if (Op == Instruction::Xor) - BinOp = Builder->CreateXor(NewLHS, NewRHS); - + Value *BinOp = Builder->CreateBinOp(I.getOpcode(), NewLHS, NewRHS); Function *F = Intrinsic::getDeclaration(I.getModule(), Intrinsic::bswap, ITy); return Builder->CreateCall(F, BinOp); } @@ -283,51 +275,31 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, } /// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise -/// (V < Lo || V >= Hi). In practice, we emit the more efficient -/// (V-Lo) \<u Hi-Lo. This method expects that Lo <= Hi. isSigned indicates -/// whether to treat the V, Lo and HI as signed or not. IB is the location to -/// insert new instructions. -Value *InstCombiner::InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, +/// (V < Lo || V >= Hi). This method expects that Lo <= Hi. IsSigned indicates +/// whether to treat V, Lo, and Hi as signed or not. +Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, bool isSigned, bool Inside) { - assert(cast<ConstantInt>(ConstantExpr::getICmp((isSigned ? - ICmpInst::ICMP_SLE:ICmpInst::ICMP_ULE), Lo, Hi))->getZExtValue() && + assert((isSigned ? Lo.sle(Hi) : Lo.ule(Hi)) && "Lo is not <= Hi in range emission code!"); - if (Inside) { - if (Lo == Hi) // Trivially false. - return Builder->getFalse(); - - // V >= Min && V < Hi --> V < Hi - if (cast<ConstantInt>(Lo)->isMinValue(isSigned)) { - ICmpInst::Predicate pred = (isSigned ? 
- ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT); - return Builder->CreateICmp(pred, V, Hi); - } - - // Emit V-Lo <u Hi-Lo - Constant *NegLo = ConstantExpr::getNeg(Lo); - Value *Add = Builder->CreateAdd(V, NegLo, V->getName()+".off"); - Constant *UpperBound = ConstantExpr::getAdd(NegLo, Hi); - return Builder->CreateICmpULT(Add, UpperBound); - } - - if (Lo == Hi) // Trivially true. - return Builder->getTrue(); + Type *Ty = V->getType(); + if (Lo == Hi) + return Inside ? ConstantInt::getFalse(Ty) : ConstantInt::getTrue(Ty); - // V < Min || V >= Hi -> V > Hi-1 - Hi = SubOne(cast<ConstantInt>(Hi)); - if (cast<ConstantInt>(Lo)->isMinValue(isSigned)) { - ICmpInst::Predicate pred = (isSigned ? - ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); - return Builder->CreateICmp(pred, V, Hi); + // V >= Min && V < Hi --> V < Hi + // V < Min || V >= Hi --> V >= Hi + ICmpInst::Predicate Pred = Inside ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; + if (isSigned ? Lo.isMinSignedValue() : Lo.isMinValue()) { + Pred = isSigned ? ICmpInst::getSignedPredicate(Pred) : Pred; + return Builder->CreateICmp(Pred, V, ConstantInt::get(Ty, Hi)); } - // Emit V-Lo >u Hi-1-Lo - // Note that Hi has already had one subtracted from it, above. - ConstantInt *NegLo = cast<ConstantInt>(ConstantExpr::getNeg(Lo)); - Value *Add = Builder->CreateAdd(V, NegLo, V->getName()+".off"); - Constant *LowerBound = ConstantExpr::getAdd(NegLo, Hi); - return Builder->CreateICmpUGT(Add, LowerBound); + // V >= Lo && V < Hi --> V - Lo u< Hi - Lo + // V < Lo || V >= Hi --> V - Lo u>= Hi - Lo + Value *VMinusLo = + Builder->CreateSub(V, ConstantInt::get(Ty, Lo), V->getName() + ".off"); + Constant *HiMinusLo = ConstantInt::get(Ty, Hi - Lo); + return Builder->CreateICmp(Pred, VMinusLo, HiMinusLo); } /// Returns true iff Val consists of one contiguous run of 1s with any number @@ -524,53 +496,6 @@ static unsigned conjugateICmpMask(unsigned Mask) { return NewMask; } -/// Decompose an icmp into the form ((X & Y) pred Z) if possible. -/// The returned predicate is either == or !=. Returns false if -/// decomposition fails. -static bool decomposeBitTestICmp(const ICmpInst *I, ICmpInst::Predicate &Pred, - Value *&X, Value *&Y, Value *&Z) { - ConstantInt *C = dyn_cast<ConstantInt>(I->getOperand(1)); - if (!C) - return false; - - switch (I->getPredicate()) { - default: - return false; - case ICmpInst::ICMP_SLT: - // X < 0 is equivalent to (X & SignBit) != 0. - if (!C->isZero()) - return false; - Y = ConstantInt::get(I->getContext(), APInt::getSignBit(C->getBitWidth())); - Pred = ICmpInst::ICMP_NE; - break; - case ICmpInst::ICMP_SGT: - // X > -1 is equivalent to (X & SignBit) == 0. - if (!C->isAllOnesValue()) - return false; - Y = ConstantInt::get(I->getContext(), APInt::getSignBit(C->getBitWidth())); - Pred = ICmpInst::ICMP_EQ; - break; - case ICmpInst::ICMP_ULT: - // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0. - if (!C->getValue().isPowerOf2()) - return false; - Y = ConstantInt::get(I->getContext(), -C->getValue()); - Pred = ICmpInst::ICMP_EQ; - break; - case ICmpInst::ICMP_UGT: - // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0. - if (!(C->getValue() + 1).isPowerOf2()) - return false; - Y = ConstantInt::get(I->getContext(), ~C->getValue()); - Pred = ICmpInst::ICMP_NE; - break; - } - - X = I->getOperand(0); - Z = ConstantInt::getNullValue(C->getType()); - return true; -} - /// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// Return the set of pattern classes (from MaskedICmpType) /// that both LHS and RHS satisfy. 
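The rewritten insertRangeTest above still lowers a two-sided range check to a single unsigned comparison: Lo <= V < Hi becomes (V - Lo) u< (Hi - Lo). That equivalence can be verified exhaustively for 8-bit values with a standalone check:

#include <cassert>
#include <cstdint>

// Lo <= V && V < Hi (unsigned, Hi exclusive, Lo <= Hi)
//   <=>  (V - Lo) <u (Hi - Lo)     with wrapping 8-bit subtraction.
int main() {
  for (uint32_t Lo = 0; Lo <= 0xFF; ++Lo)
    for (uint32_t Hi = Lo; Hi <= 0xFF; ++Hi) // the code requires Lo <= Hi
      for (uint32_t V = 0; V <= 0xFF; ++V) {
        bool Direct = (V >= Lo) && (V < Hi);
        bool Folded =
            static_cast<uint8_t>(V - Lo) < static_cast<uint8_t>(Hi - Lo);
        assert(Direct == Folded);
      }
  return 0;
}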
@@ -1001,7 +926,8 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13 return Builder->CreateICmpULT(Val, LHSCst); if (LHSCst->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13 - return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, false, true); + return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + false, true); break; // (X != 13 & X u< 15) -> no change case ICmpInst::ICMP_SLT: if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13 @@ -1065,7 +991,8 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { return Builder->CreateICmp(LHSCC, Val, RHSCst); break; // (X u> 13 & X != 15) -> no change case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 - return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, false, true); + return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + false, true); case ICmpInst::ICMP_SLT: // (X u> 13 & X s< 15) -> no change break; } @@ -1083,7 +1010,8 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { return Builder->CreateICmp(LHSCC, Val, RHSCst); break; // (X s> 13 & X != 15) -> no change case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 - return InsertRangeTest(Val, AddOne(LHSCst), RHSCst, true, true); + return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + true, true); case ICmpInst::ICMP_ULT: // (X s> 13 & X u< 15) -> no change break; } @@ -1170,34 +1098,73 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I, return BinaryOperator::CreateNot(LogicOp); } - // De Morgan's Law in disguise: - // (zext(bool A) ^ 1) & (zext(bool B) ^ 1) -> zext(~(A | B)) - // (zext(bool A) ^ 1) | (zext(bool B) ^ 1) -> zext(~(A & B)) - Value *A = nullptr; - Value *B = nullptr; - ConstantInt *C1 = nullptr; - if (match(Op0, m_OneUse(m_Xor(m_ZExt(m_Value(A)), m_ConstantInt(C1)))) && - match(Op1, m_OneUse(m_Xor(m_ZExt(m_Value(B)), m_Specific(C1))))) { - // TODO: This check could be loosened to handle different type sizes. - // Alternatively, we could fix the definition of m_Not to recognize a not - // operation hidden by a zext? - if (A->getType()->isIntegerTy(1) && B->getType()->isIntegerTy(1) && - C1->isOne()) { - Value *LogicOp = Builder->CreateBinOp(Opcode, A, B, - I.getName() + ".demorgan"); - Value *Not = Builder->CreateNot(LogicOp); - return CastInst::CreateZExtOrBitCast(Not, I.getType()); + return nullptr; +} + +bool InstCombiner::shouldOptimizeCast(CastInst *CI) { + Value *CastSrc = CI->getOperand(0); + + // Noop casts and casts of constants should be eliminated trivially. + if (CI->getSrcTy() == CI->getDestTy() || isa<Constant>(CastSrc)) + return false; + + // If this cast is paired with another cast that can be eliminated, we prefer + // to have it eliminated. + if (const auto *PrecedingCI = dyn_cast<CastInst>(CastSrc)) + if (isEliminableCastPair(PrecedingCI, CI)) + return false; + + // If this is a vector sext from a compare, then we don't want to break the + // idiom where each element of the extended vector is either zero or all ones. + if (CI->getOpcode() == Instruction::SExt && + isa<CmpInst>(CastSrc) && CI->getDestTy()->isVectorTy()) + return false; + + return true; +} + +/// Fold {and,or,xor} (cast X), C. 
+static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, + InstCombiner::BuilderTy *Builder) { + Constant *C; + if (!match(Logic.getOperand(1), m_Constant(C))) + return nullptr; + + auto LogicOpc = Logic.getOpcode(); + Type *DestTy = Logic.getType(); + Type *SrcTy = Cast->getSrcTy(); + + // If the first operand is bitcast, move the logic operation ahead of the + // bitcast (do the logic operation in the original type). This can eliminate + // bitcasts and allow combines that would otherwise be impeded by the bitcast. + Value *X; + if (match(Cast, m_BitCast(m_Value(X)))) { + Value *NewConstant = ConstantExpr::getBitCast(C, SrcTy); + Value *NewOp = Builder->CreateBinOp(LogicOpc, X, NewConstant); + return CastInst::CreateBitOrPointerCast(NewOp, DestTy); + } + + // Similarly, move the logic operation ahead of a zext if the constant is + // unchanged in the smaller source type. Performing the logic in a smaller + // type may provide more information to later folds, and the smaller logic + // instruction may be cheaper (particularly in the case of vectors). + if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) { + Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); + Constant *ZextTruncC = ConstantExpr::getZExt(TruncC, DestTy); + if (ZextTruncC == C) { + // LogicOpc (zext X), C --> zext (LogicOpc X, C) + Value *NewOp = Builder->CreateBinOp(LogicOpc, X, TruncC); + return new ZExtInst(NewOp, DestTy); } } return nullptr; } +/// Fold {and,or,xor} (cast X), Y. Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { auto LogicOpc = I.getOpcode(); - assert((LogicOpc == Instruction::And || LogicOpc == Instruction::Or || - LogicOpc == Instruction::Xor) && - "Unexpected opcode for bitwise logic folding"); + assert(I.isBitwiseLogicOp() && "Unexpected opcode for bitwise logic folding"); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); CastInst *Cast0 = dyn_cast<CastInst>(Op0); @@ -1211,18 +1178,8 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { if (!SrcTy->isIntOrIntVectorTy()) return nullptr; - // If one operand is a bitcast and the other is a constant, move the logic - // operation ahead of the bitcast. That is, do the logic operation in the - // original type. This can eliminate useless bitcasts and allow normal - // combines that would otherwise be impeded by the bitcast. Canonicalization - // ensures that if there is a constant operand, it will be the second operand. - Value *BC = nullptr; - Constant *C = nullptr; - if ((match(Op0, m_BitCast(m_Value(BC))) && match(Op1, m_Constant(C)))) { - Value *NewConstant = ConstantExpr::getBitCast(C, SrcTy); - Value *NewOp = Builder->CreateBinOp(LogicOpc, BC, NewConstant, I.getName()); - return CastInst::CreateBitOrPointerCast(NewOp, DestTy); - } + if (Instruction *Ret = foldLogicCastConstant(I, Cast0, Builder)) + return Ret; CastInst *Cast1 = dyn_cast<CastInst>(Op1); if (!Cast1) @@ -1237,12 +1194,8 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { Value *Cast0Src = Cast0->getOperand(0); Value *Cast1Src = Cast1->getOperand(0); - // fold (logic (cast A), (cast B)) -> (cast (logic A, B)) - - // Only do this if the casts both really cause code to be generated. 
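The zext case of foldLogicCastConstant above fires only when the constant survives the trunc/zext round trip; under that condition, doing the logic in the narrow type and zero-extending afterwards is equivalent. A standalone i8/i32 check of the 'and' case (the same holds for or and xor, since a round-tripping constant has no bits above i8):

#include <cassert>
#include <cstdint>

// and (zext i8 X), C  ==  zext (and i8 X, trunc C)   when zext(trunc C) == C.
int main() {
  const uint32_t Cs[] = {0x0000007Fu, 0x000000FFu, 0x00000001u, 0x000000A5u};
  for (uint32_t C : Cs) {
    assert(static_cast<uint32_t>(static_cast<uint8_t>(C)) == C);
    for (uint32_t V = 0; V <= 0xFF; ++V) {
      uint8_t X = static_cast<uint8_t>(V);
      uint32_t Wide = static_cast<uint32_t>(X) & C;
      uint32_t Narrow = static_cast<uint32_t>(
          static_cast<uint8_t>(X & static_cast<uint8_t>(C)));
      assert(Wide == Narrow);
    }
  }
  return 0;
}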
- if ((!isa<ICmpInst>(Cast0Src) || !isa<ICmpInst>(Cast1Src)) && - ShouldOptimizeCast(CastOpcode, Cast0Src, DestTy) && - ShouldOptimizeCast(CastOpcode, Cast1Src, DestTy)) { + // fold logic(cast(A), cast(B)) -> cast(logic(A, B)) + if (shouldOptimizeCast(Cast0) && shouldOptimizeCast(Cast1)) { Value *NewOp = Builder->CreateBinOp(LogicOpc, Cast0Src, Cast1Src, I.getName()); return CastInst::Create(CastOpcode, NewOp, DestTy); @@ -1301,10 +1254,13 @@ static Instruction *foldBoolSextMaskToSelect(BinaryOperator &I) { Value *Zero = Constant::getNullValue(Op0->getType()); return SelectInst::Create(X, Zero, Op1); } - + return nullptr; } +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitAnd(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1312,7 +1268,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyAndInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyAndInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A|B)&(A|C) -> A|(B&C) etc @@ -1426,13 +1382,8 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { } } - // Try to fold constant and into select arguments. - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) + return FoldedLogic; } if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) @@ -1503,8 +1454,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return BinaryOperator::CreateAnd(A, B); // ((~A) ^ B) & (A | B) -> (A & B) + // ((~A) ^ B) & (B | A) -> (A & B) if (match(Op0, m_Xor(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_Or(m_Specific(A), m_Specific(B)))) + match(Op1, m_c_Or(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateAnd(A, B); } @@ -1697,17 +1649,17 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Value *Mask = nullptr; Value *Masked = nullptr; if (LAnd->getOperand(0) == RAnd->getOperand(0) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(1), DL, false, 0, AC, CxtI, - DT) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(1), DL, false, 0, AC, CxtI, - DT)) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(1), DL, false, 0, &AC, CxtI, + &DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(1), DL, false, 0, &AC, CxtI, + &DT)) { Mask = Builder->CreateOr(LAnd->getOperand(1), RAnd->getOperand(1)); Masked = Builder->CreateAnd(LAnd->getOperand(0), Mask); } else if (LAnd->getOperand(1) == RAnd->getOperand(1) && - isKnownToBeAPowerOfTwo(LAnd->getOperand(0), DL, false, 0, AC, - CxtI, DT) && - isKnownToBeAPowerOfTwo(RAnd->getOperand(0), DL, false, 0, AC, - CxtI, DT)) { + isKnownToBeAPowerOfTwo(LAnd->getOperand(0), DL, false, 0, &AC, + CxtI, &DT) && + isKnownToBeAPowerOfTwo(RAnd->getOperand(0), DL, false, 0, &AC, + CxtI, &DT)) { Mask = Builder->CreateOr(LAnd->getOperand(0), RAnd->getOperand(0)); Masked = Builder->CreateAnd(LAnd->getOperand(1), Mask); } @@ -1825,7 +1777,7 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // E.g. 
(icmp sgt x, n) | (icmp slt x, 0) --> icmp ugt x, n if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true)) return V; - + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). if (!LHSCst || !RHSCst) return nullptr; @@ -1943,7 +1895,8 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // this can cause overflow. if (RHSCst->isMaxValue(false)) return LHS; - return InsertRangeTest(Val, LHSCst, AddOne(RHSCst), false, false); + return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, + false, false); case ICmpInst::ICMP_SGT: // (X u< 13 | X s> 15) -> no change break; case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15 @@ -1963,7 +1916,8 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // this can cause overflow. if (RHSCst->isMaxValue(true)) return LHS; - return InsertRangeTest(Val, LHSCst, AddOne(RHSCst), true, false); + return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, + true, false); case ICmpInst::ICMP_UGT: // (X s< 13 | X u> 15) -> no change break; case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15 @@ -2119,6 +2073,9 @@ Instruction *InstCombiner::FoldXorWithConstants(BinaryOperator &I, Value *Op, return nullptr; } +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitOr(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -2126,7 +2083,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyOrInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyOrInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A&B)|(A&C) -> A&(B|C) etc @@ -2163,14 +2120,8 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Builder->getInt(C1->getValue() & ~RHS->getValue())); } - // Try to fold constant and into select arguments. - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) + return FoldedLogic; } // Given an OR instruction, check to see if this is a bswap. @@ -2208,14 +2159,17 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { match(Op1, m_Not(m_Specific(A)))) return BinaryOperator::CreateOr(Builder->CreateNot(A), B); - // (A & (~B)) | (A ^ B) -> (A ^ B) - if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && + // (A & ~B) | (A ^ B) -> (A ^ B) + // (~B & A) | (A ^ B) -> (A ^ B) + if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Xor(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); - // (A ^ B) | ( A & (~B)) -> (A ^ B) - if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && - match(Op1, m_And(m_Specific(A), m_Not(m_Specific(B))))) + // Commute the 'or' operands. 
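Two of the visitOr folds touched in this hunk are plain Boolean identities, which a short exhaustive 8-bit check confirms independently of the commuted-matcher changes:

#include <cassert>
#include <cstdint>

// (A & ~B) | (A ^ B)  ==  A ^ B
// (A & B)  | (~A ^ B) ==  ~A ^ B
int main() {
  for (uint32_t A = 0; A <= 0xFF; ++A)
    for (uint32_t B = 0; B <= 0xFF; ++B) {
      uint8_t a = static_cast<uint8_t>(A), b = static_cast<uint8_t>(B);
      assert(static_cast<uint8_t>((a & ~b) | (a ^ b)) ==
             static_cast<uint8_t>(a ^ b));
      assert(static_cast<uint8_t>((a & b) | (~a ^ b)) ==
             static_cast<uint8_t>(~a ^ b));
    }
  return 0;
}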
+ // (A ^ B) | (A & ~B) -> (A ^ B) + // (A ^ B) | (~B & A) -> (A ^ B) + if (match(Op1, m_c_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op0, m_Xor(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); // (A & C)|(B & D) @@ -2385,14 +2339,15 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return BinaryOperator::CreateOr(Not, Op0); } - // (A & B) | ((~A) ^ B) -> (~A ^ B) - if (match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_Xor(m_Not(m_Specific(A)), m_Specific(B)))) - return BinaryOperator::CreateXor(Builder->CreateNot(A), B); - - // ((~A) ^ B) | (A & B) -> (~A ^ B) - if (match(Op0, m_Xor(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_And(m_Specific(A), m_Specific(B)))) + // (A & B) | (~A ^ B) -> (~A ^ B) + // (A & B) | (B ^ ~A) -> (~A ^ B) + // (B & A) | (~A ^ B) -> (~A ^ B) + // (B & A) | (B ^ ~A) -> (~A ^ B) + // The match order is important: match the xor first because the 'not' + // operation defines 'A'. We do not need to match the xor as Op0 because the + // xor was canonicalized to Op1 above. + if (match(Op1, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) && + match(Op0, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateXor(Builder->CreateNot(A), B); if (SwappedForXor) @@ -2472,6 +2427,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return Changed ? &I : nullptr; } +// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches +// here. We should standardize that construct where it is needed or choose some +// other way to ensure that commutated variants of patterns are not missed. Instruction *InstCombiner::visitXor(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -2479,7 +2437,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyXorInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyXorInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // (A&B)^(A&C) -> A&(B^C) etc @@ -2625,13 +2583,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - // Try to fold constant and into select arguments. 
- if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) + return FoldedLogic; } BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1); @@ -2694,20 +2647,22 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return BinaryOperator::CreateXor(A, B); } // (A | ~B) ^ (~A | B) -> A ^ B - if (match(Op0I, m_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) { + // (~B | A) ^ (~A | B) -> A ^ B + if (match(Op0I, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && + match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); - } + // (~A | B) ^ (A | ~B) -> A ^ B if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) && match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) { return BinaryOperator::CreateXor(A, B); } // (A & ~B) ^ (~A & B) -> A ^ B - if (match(Op0I, m_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) { + // (~B & A) ^ (~A & B) -> A ^ B + if (match(Op0I, m_c_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); - } + // (~A & B) ^ (A & ~B) -> A ^ B if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) && match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) { @@ -2743,9 +2698,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return BinaryOperator::CreateOr(A, B); } - Value *A = nullptr, *B = nullptr; - // (A & ~B) ^ (~A) -> ~(A & B) - if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && + // (A & ~B) ^ ~A -> ~(A & B) + // (~B & A) ^ ~A -> ~(A & B) + Value *A, *B; + if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Not(m_Specific(A)))) return BinaryOperator::CreateNot(Builder->CreateAnd(A, B)); diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 8acff91..2ef82ba 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -12,17 +12,47 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" -#include "llvm/IR/Dominators.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Statepoint.h" -#include "llvm/Transforms/Utils/BuildLibCalls.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include 
"llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <cstring> +#include <vector> + using namespace llvm; using namespace PatternMatch; @@ -79,8 +109,8 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { } Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { - unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, AC, DT); - unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, AC, DT); + unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, &AC, &DT); + unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, &AC, &DT); unsigned MinAlign = std::min(DstAlign, SrcAlign); unsigned CopyAlign = MI->getAlignment(); @@ -162,10 +192,17 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { L->setAlignment(SrcAlign); if (CopyMD) L->setMetadata(LLVMContext::MD_tbaa, CopyMD); + MDNode *LoopMemParallelMD = + MI->getMetadata(LLVMContext::MD_mem_parallel_loop_access); + if (LoopMemParallelMD) + L->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); + StoreInst *S = Builder->CreateStore(L, Dest, MI->isVolatile()); S->setAlignment(DstAlign); if (CopyMD) S->setMetadata(LLVMContext::MD_tbaa, CopyMD); + if (LoopMemParallelMD) + S->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); // Set the size of the copy to 0, it will be deleted on the next iteration. MI->setArgOperand(2, Constant::getNullValue(MemOpLength->getType())); @@ -173,7 +210,7 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { } Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) { - unsigned Alignment = getKnownAlignment(MI->getDest(), DL, MI, AC, DT); + unsigned Alignment = getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT); if (MI->getAlignment() < Alignment) { MI->setAlignment(ConstantInt::get(MI->getAlignmentType(), Alignment, false)); @@ -221,8 +258,7 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, bool ShiftLeft = false; switch (II.getIntrinsicID()) { - default: - return nullptr; + default: llvm_unreachable("Unexpected intrinsic!"); case Intrinsic::x86_sse2_psra_d: case Intrinsic::x86_sse2_psra_w: case Intrinsic::x86_sse2_psrai_d: @@ -231,6 +267,16 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_psra_w: case Intrinsic::x86_avx2_psrai_d: case Intrinsic::x86_avx2_psrai_w: + case Intrinsic::x86_avx512_psra_q_128: + case Intrinsic::x86_avx512_psrai_q_128: + case Intrinsic::x86_avx512_psra_q_256: + case Intrinsic::x86_avx512_psrai_q_256: + case Intrinsic::x86_avx512_psra_d_512: + case Intrinsic::x86_avx512_psra_q_512: + case Intrinsic::x86_avx512_psra_w_512: + case Intrinsic::x86_avx512_psrai_d_512: + case Intrinsic::x86_avx512_psrai_q_512: + case Intrinsic::x86_avx512_psrai_w_512: LogicalShift = false; ShiftLeft = false; break; case Intrinsic::x86_sse2_psrl_d: @@ -245,6 +291,12 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_psrli_d: case Intrinsic::x86_avx2_psrli_q: case Intrinsic::x86_avx2_psrli_w: + case Intrinsic::x86_avx512_psrl_d_512: + case Intrinsic::x86_avx512_psrl_q_512: + case Intrinsic::x86_avx512_psrl_w_512: + case Intrinsic::x86_avx512_psrli_d_512: + case Intrinsic::x86_avx512_psrli_q_512: + case Intrinsic::x86_avx512_psrli_w_512: LogicalShift = true; ShiftLeft = 
false; break; case Intrinsic::x86_sse2_psll_d: @@ -259,6 +311,12 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_pslli_d: case Intrinsic::x86_avx2_pslli_q: case Intrinsic::x86_avx2_pslli_w: + case Intrinsic::x86_avx512_psll_d_512: + case Intrinsic::x86_avx512_psll_q_512: + case Intrinsic::x86_avx512_psll_w_512: + case Intrinsic::x86_avx512_pslli_d_512: + case Intrinsic::x86_avx512_pslli_q_512: + case Intrinsic::x86_avx512_pslli_w_512: LogicalShift = true; ShiftLeft = true; break; } @@ -334,10 +392,16 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, bool ShiftLeft = false; switch (II.getIntrinsicID()) { - default: - return nullptr; + default: llvm_unreachable("Unexpected intrinsic!"); case Intrinsic::x86_avx2_psrav_d: case Intrinsic::x86_avx2_psrav_d_256: + case Intrinsic::x86_avx512_psrav_q_128: + case Intrinsic::x86_avx512_psrav_q_256: + case Intrinsic::x86_avx512_psrav_d_512: + case Intrinsic::x86_avx512_psrav_q_512: + case Intrinsic::x86_avx512_psrav_w_128: + case Intrinsic::x86_avx512_psrav_w_256: + case Intrinsic::x86_avx512_psrav_w_512: LogicalShift = false; ShiftLeft = false; break; @@ -345,6 +409,11 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_psrlv_d_256: case Intrinsic::x86_avx2_psrlv_q: case Intrinsic::x86_avx2_psrlv_q_256: + case Intrinsic::x86_avx512_psrlv_d_512: + case Intrinsic::x86_avx512_psrlv_q_512: + case Intrinsic::x86_avx512_psrlv_w_128: + case Intrinsic::x86_avx512_psrlv_w_256: + case Intrinsic::x86_avx512_psrlv_w_512: LogicalShift = true; ShiftLeft = false; break; @@ -352,6 +421,11 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, case Intrinsic::x86_avx2_psllv_d_256: case Intrinsic::x86_avx2_psllv_q: case Intrinsic::x86_avx2_psllv_q_256: + case Intrinsic::x86_avx512_psllv_d_512: + case Intrinsic::x86_avx512_psllv_q_512: + case Intrinsic::x86_avx512_psllv_w_128: + case Intrinsic::x86_avx512_psllv_w_256: + case Intrinsic::x86_avx512_psllv_w_512: LogicalShift = true; ShiftLeft = true; break; @@ -400,7 +474,7 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, // If all elements out of range or UNDEF, return vector of zeros/undefs. // ArithmeticShift should only hit this if they are all UNDEF. auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); }; - if (llvm::all_of(ShiftAmts, OutOfRange)) { + if (all_of(ShiftAmts, OutOfRange)) { SmallVector<Constant *, 8> ConstantVec; for (int Idx : ShiftAmts) { if (Idx < 0) { @@ -547,7 +621,7 @@ static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, // See if we're dealing with constant values. Constant *C0 = dyn_cast<Constant>(Op0); ConstantInt *CI0 = - C0 ? dyn_cast<ConstantInt>(C0->getAggregateElement((unsigned)0)) + C0 ? dyn_cast_or_null<ConstantInt>(C0->getAggregateElement((unsigned)0)) : nullptr; // Attempt to constant fold. @@ -630,7 +704,6 @@ static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, APInt APLength, APInt APIndex, InstCombiner::BuilderTy &Builder) { - // From AMD documentation: "The bit index and field length are each six bits // in length other bits of the field are ignored." APIndex = APIndex.zextOrTrunc(6); @@ -686,10 +759,10 @@ static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, Constant *C0 = dyn_cast<Constant>(Op0); Constant *C1 = dyn_cast<Constant>(Op1); ConstantInt *CI00 = - C0 ? dyn_cast<ConstantInt>(C0->getAggregateElement((unsigned)0)) + C0 ? 
dyn_cast_or_null<ConstantInt>(C0->getAggregateElement((unsigned)0)) : nullptr; ConstantInt *CI10 = - C1 ? dyn_cast<ConstantInt>(C1->getAggregateElement((unsigned)0)) + C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)0)) : nullptr; // Constant Fold - insert bottom Length bits starting at the Index'th bit. @@ -732,11 +805,11 @@ static Value *simplifyX86pshufb(const IntrinsicInst &II, auto *VecTy = cast<VectorType>(II.getType()); auto *MaskEltTy = Type::getInt32Ty(II.getContext()); unsigned NumElts = VecTy->getNumElements(); - assert((NumElts == 16 || NumElts == 32) && + assert((NumElts == 16 || NumElts == 32 || NumElts == 64) && "Unexpected number of elements in shuffle mask!"); // Construct a shuffle mask from constant integers or UNDEFs. - Constant *Indexes[32] = {NULL}; + Constant *Indexes[64] = {nullptr}; // Each byte in the shuffle control mask forms an index to permute the // corresponding byte in the destination operand. @@ -776,12 +849,15 @@ static Value *simplifyX86vpermilvar(const IntrinsicInst &II, if (!V) return nullptr; + auto *VecTy = cast<VectorType>(II.getType()); auto *MaskEltTy = Type::getInt32Ty(II.getContext()); - unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); - assert(NumElts == 8 || NumElts == 4 || NumElts == 2); + unsigned NumElts = VecTy->getVectorNumElements(); + bool IsPD = VecTy->getScalarType()->isDoubleTy(); + unsigned NumLaneElts = IsPD ? 2 : 4; + assert(NumElts == 16 || NumElts == 8 || NumElts == 4 || NumElts == 2); // Construct a shuffle mask from constant integers or UNDEFs. - Constant *Indexes[8] = {NULL}; + Constant *Indexes[16] = {nullptr}; // The intrinsics only read one or two bits, clear the rest. for (unsigned I = 0; I < NumElts; ++I) { @@ -799,18 +875,13 @@ static Value *simplifyX86vpermilvar(const IntrinsicInst &II, // The PD variants uses bit 1 to select per-lane element index, so // shift down to convert to generic shuffle mask index. - if (II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd || - II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) + if (IsPD) Index = Index.lshr(1); // The _256 variants are a bit trickier since the mask bits always index // into the corresponding 128 half. In order to convert to a generic // shuffle, we have to make that explicit. - if ((II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_ps_256 || - II.getIntrinsicID() == Intrinsic::x86_avx_vpermilvar_pd_256) && - ((NumElts / 2) <= I)) { - Index += APInt(32, NumElts / 2); - } + Index += APInt(32, (I / NumLaneElts) * NumLaneElts); Indexes[I] = ConstantInt::get(MaskEltTy, Index); } @@ -831,10 +902,11 @@ static Value *simplifyX86vpermv(const IntrinsicInst &II, auto *VecTy = cast<VectorType>(II.getType()); auto *MaskEltTy = Type::getInt32Ty(II.getContext()); unsigned Size = VecTy->getNumElements(); - assert(Size == 8 && "Unexpected shuffle mask size"); + assert((Size == 4 || Size == 8 || Size == 16 || Size == 32 || Size == 64) && + "Unexpected shuffle mask size"); // Construct a shuffle mask from constant integers or UNDEFs. 
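For the widened vpermilvar handling above, the 256- and 512-bit forms keep per-128-bit-lane indices, which the code converts to whole-vector shuffle indices by adding a lane base. A small illustrative sketch of that index rewrite (control values chosen arbitrarily, not taken from a test):

// Per-lane index rewrite used by the vpermilvar simplification above: the
// control value selects within a 128-bit lane, so the absolute shuffle index
// is the low control bits plus the lane base of element I.
#include <cstdio>

int main() {
  const unsigned NumElts = 8, NumLaneElts = 4;       // e.g. the *_ps_256 form
  const unsigned Control[NumElts] = {3, 2, 1, 0, 3, 2, 1, 0};
  for (unsigned I = 0; I < NumElts; ++I) {
    unsigned Index = Control[I] & (NumLaneElts - 1); // only the low bits are read
    Index += (I / NumLaneElts) * NumLaneElts;        // add the lane base
    printf("dst[%u] <- src[%u]\n", I, Index);        // e.g. dst[4] <- src[7]
  }
  return 0;
}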
- Constant *Indexes[8] = {NULL}; + Constant *Indexes[64] = {nullptr}; for (unsigned I = 0; I < Size; ++I) { Constant *COp = V->getAggregateElement(I); @@ -846,8 +918,8 @@ static Value *simplifyX86vpermv(const IntrinsicInst &II, continue; } - APInt Index = cast<ConstantInt>(COp)->getValue(); - Index = Index.zextOrTrunc(32).getLoBits(3); + uint32_t Index = cast<ConstantInt>(COp)->getZExtValue(); + Index &= Size - 1; Indexes[I] = ConstantInt::get(MaskEltTy, Index); } @@ -962,6 +1034,36 @@ static Value *simplifyX86vpcom(const IntrinsicInst &II, return nullptr; } +// Emit a select instruction and appropriate bitcasts to help simplify +// masked intrinsics. +static Value *emitX86MaskSelect(Value *Mask, Value *Op0, Value *Op1, + InstCombiner::BuilderTy &Builder) { + unsigned VWidth = Op0->getType()->getVectorNumElements(); + + // If the mask is all ones we don't need the select. But we need to check + // only the bit thats will be used in case VWidth is less than 8. + if (auto *C = dyn_cast<ConstantInt>(Mask)) + if (C->getValue().zextOrTrunc(VWidth).isAllOnesValue()) + return Op0; + + auto *MaskTy = VectorType::get(Builder.getInt1Ty(), + cast<IntegerType>(Mask->getType())->getBitWidth()); + Mask = Builder.CreateBitCast(Mask, MaskTy); + + // If we have less than 8 elements, then the starting mask was an i8 and + // we need to extract down to the right number of elements. + if (VWidth < 8) { + uint32_t Indices[4]; + for (unsigned i = 0; i != VWidth; ++i) + Indices[i] = i; + Mask = Builder.CreateShuffleVector(Mask, Mask, + makeArrayRef(Indices, VWidth), + "extract"); + } + + return Builder.CreateSelect(Mask, Op0, Op1); +} + static Value *simplifyMinnumMaxnum(const IntrinsicInst &II) { Value *Arg0 = II.getArgOperand(0); Value *Arg1 = II.getArgOperand(1); @@ -1104,6 +1206,50 @@ static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) { return nullptr; } +static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { + assert((II.getIntrinsicID() == Intrinsic::cttz || + II.getIntrinsicID() == Intrinsic::ctlz) && + "Expected cttz or ctlz intrinsic"); + Value *Op0 = II.getArgOperand(0); + // FIXME: Try to simplify vectors of integers. + auto *IT = dyn_cast<IntegerType>(Op0->getType()); + if (!IT) + return nullptr; + + unsigned BitWidth = IT->getBitWidth(); + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + IC.computeKnownBits(Op0, KnownZero, KnownOne, 0, &II); + + // Create a mask for bits above (ctlz) or below (cttz) the first known one. + bool IsTZ = II.getIntrinsicID() == Intrinsic::cttz; + unsigned NumMaskBits = IsTZ ? KnownOne.countTrailingZeros() + : KnownOne.countLeadingZeros(); + APInt Mask = IsTZ ? APInt::getLowBitsSet(BitWidth, NumMaskBits) + : APInt::getHighBitsSet(BitWidth, NumMaskBits); + + // If all bits above (ctlz) or below (cttz) the first known one are known + // zero, this value is constant. + // FIXME: This should be in InstSimplify because we're replacing an + // instruction with a constant. + if ((Mask & KnownZero) == Mask) { + auto *C = ConstantInt::get(IT, APInt(BitWidth, NumMaskBits)); + return IC.replaceInstUsesWith(II, C); + } + + // If the input to cttz/ctlz is known to be non-zero, + // then change the 'ZeroIsUndef' parameter to 'true' + // because we know the zero behavior can't affect the result. 
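foldCttzCtlz above turns a cttz/ctlz call into a constant when the known bits pin down the position of the first set bit. A minimal numeric sketch of that reasoning (hypothetical values, cttz side only, using the GCC/Clang __builtin_ctz for illustration):

// If bits [2:0] are known zero and bit 3 is known one, every possible value
// has its lowest set bit at position 3, so cttz folds to the constant 3.
#include <cassert>

int main() {
  for (unsigned Hi = 0; Hi < 16; ++Hi) {   // the unknown upper bits [7:4]
    unsigned V = (Hi << 4) | 0x8;          // bit 3 known one, bits 2..0 known zero
    assert(__builtin_ctz(V) == 3);         // matches the folded constant
  }
  return 0;
}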
+ if (KnownOne != 0 || isKnownNonZero(Op0, IC.getDataLayout())) { + if (!match(II.getArgOperand(1), m_One())) { + II.setOperand(1, IC.Builder->getTrue()); + return &II; + } + } + + return nullptr; +} + // TODO: If the x86 backend knew how to convert a bool vector mask back to an // XMM register mask efficiently, we could transform all x86 masked intrinsics // to LLVM masked intrinsics and remove the x86 masked intrinsic defs. @@ -1243,16 +1389,15 @@ Instruction *InstCombiner::visitVACopyInst(VACopyInst &I) { Instruction *InstCombiner::visitCallInst(CallInst &CI) { auto Args = CI.arg_operands(); if (Value *V = SimplifyCall(CI.getCalledValue(), Args.begin(), Args.end(), DL, - TLI, DT, AC)) + &TLI, &DT, &AC)) return replaceInstUsesWith(CI, V); - if (isFreeCall(&CI, TLI)) + if (isFreeCall(&CI, &TLI)) return visitFree(CI); // If the caller function is nounwind, mark the call as nounwind, even if the // callee isn't. - if (CI.getParent()->getParent()->doesNotThrow() && - !CI.doesNotThrow()) { + if (CI.getFunction()->doesNotThrow() && !CI.doesNotThrow()) { CI.setDoesNotThrow(); return &CI; } @@ -1323,26 +1468,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { APInt DemandedElts = APInt::getLowBitsSet(Width, DemandedWidth); return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); }; - auto SimplifyDemandedVectorEltsHigh = [this](Value *Op, unsigned Width, - unsigned DemandedWidth) { - APInt UndefElts(Width, 0); - APInt DemandedElts = APInt::getHighBitsSet(Width, DemandedWidth); - return SimplifyDemandedVectorElts(Op, DemandedElts, UndefElts); - }; switch (II->getIntrinsicID()) { default: break; - case Intrinsic::objectsize: { - uint64_t Size; - if (getObjectSize(II->getArgOperand(0), Size, DL, TLI)) { - APInt APSize(II->getType()->getIntegerBitWidth(), Size); - // Equality check to be sure that `Size` can fit in a value of type - // `II->getType()` - if (APSize == Size) - return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), APSize)); - } + case Intrinsic::objectsize: + if (ConstantInt *N = + lowerObjectSizeCall(II, DL, &TLI, /*MustSucceed=*/false)) + return replaceInstUsesWith(CI, N); return nullptr; - } + case Intrinsic::bswap: { Value *IIOperand = II->getArgOperand(0); Value *X = nullptr; @@ -1397,41 +1531,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->getArgOperand(0)); } break; - case Intrinsic::cttz: { - // If all bits below the first known one are known zero, - // this value is constant. - IntegerType *IT = dyn_cast<IntegerType>(II->getArgOperand(0)->getType()); - // FIXME: Try to simplify vectors of integers. - if (!IT) break; - uint32_t BitWidth = IT->getBitWidth(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II); - unsigned TrailingZeros = KnownOne.countTrailingZeros(); - APInt Mask(APInt::getLowBitsSet(BitWidth, TrailingZeros)); - if ((Mask & KnownZero) == Mask) - return replaceInstUsesWith(CI, ConstantInt::get(IT, - APInt(BitWidth, TrailingZeros))); - - } - break; - case Intrinsic::ctlz: { - // If all bits above the first known one are known zero, - // this value is constant. - IntegerType *IT = dyn_cast<IntegerType>(II->getArgOperand(0)->getType()); - // FIXME: Try to simplify vectors of integers. 
- if (!IT) break; - uint32_t BitWidth = IT->getBitWidth(); - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(II->getArgOperand(0), KnownZero, KnownOne, 0, II); - unsigned LeadingZeros = KnownOne.countLeadingZeros(); - APInt Mask(APInt::getHighBitsSet(BitWidth, LeadingZeros)); - if ((Mask & KnownZero) == Mask) - return replaceInstUsesWith(CI, ConstantInt::get(IT, - APInt(BitWidth, LeadingZeros))); - } + case Intrinsic::cttz: + case Intrinsic::ctlz: + if (auto *I = foldCttzCtlz(*II, *this)) + return I; break; case Intrinsic::uadd_with_overflow: @@ -1446,7 +1550,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->setArgOperand(1, LHS); return II; } - // fall through + LLVM_FALLTHROUGH; case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: { @@ -1477,11 +1581,77 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); break; } + case Intrinsic::fma: + case Intrinsic::fmuladd: { + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + + // Canonicalize constants into the RHS. + if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { + II->setArgOperand(0, Src1); + II->setArgOperand(1, Src0); + std::swap(Src0, Src1); + } + + Value *LHS = nullptr; + Value *RHS = nullptr; + + // fma fneg(x), fneg(y), z -> fma x, y, z + if (match(Src0, m_FNeg(m_Value(LHS))) && + match(Src1, m_FNeg(m_Value(RHS)))) { + II->setArgOperand(0, LHS); + II->setArgOperand(1, RHS); + return II; + } + + // fma fabs(x), fabs(x), z -> fma x, x, z + if (match(Src0, m_Intrinsic<Intrinsic::fabs>(m_Value(LHS))) && + match(Src1, m_Intrinsic<Intrinsic::fabs>(m_Value(RHS))) && LHS == RHS) { + II->setArgOperand(0, LHS); + II->setArgOperand(1, RHS); + return II; + } + + // fma x, 1, z -> fadd x, z + if (match(Src1, m_FPOne())) { + Instruction *RI = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2)); + RI->copyFastMathFlags(II); + return RI; + } + + break; + } + case Intrinsic::fabs: { + Value *Cond; + Constant *LHS, *RHS; + if (match(II->getArgOperand(0), + m_Select(m_Value(Cond), m_Constant(LHS), m_Constant(RHS)))) { + CallInst *Call0 = Builder->CreateCall(II->getCalledFunction(), {LHS}); + CallInst *Call1 = Builder->CreateCall(II->getCalledFunction(), {RHS}); + return SelectInst::Create(Cond, Call0, Call1); + } + + break; + } + case Intrinsic::cos: + case Intrinsic::amdgcn_cos: { + Value *SrcSrc; + Value *Src = II->getArgOperand(0); + if (match(Src, m_FNeg(m_Value(SrcSrc))) || + match(Src, m_Intrinsic<Intrinsic::fabs>(m_Value(SrcSrc)))) { + // cos(-x) -> cos(x) + // cos(fabs(x)) -> cos(x) + II->setArgOperand(0, SrcSrc); + return II; + } + + break; + } case Intrinsic::ppc_altivec_lvx: case Intrinsic::ppc_altivec_lvxl: // Turn PPC lvx -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, AC, DT) >= - 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, &AC, + &DT) >= 16) { Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); return new LoadInst(Ptr); @@ -1497,8 +1667,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::ppc_altivec_stvx: case Intrinsic::ppc_altivec_stvxl: // Turn stvx -> store if the pointer is known aligned. 
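The fma/fmuladd and cos folds above rely on exact floating-point identities rather than fast-math: flipping the sign of both multiplicands, taking fabs of both copies of the same multiplicand, or multiplying by 1.0 cannot change the rounded result, and cos is an even function. A standalone spot check (illustrative only) for the fma cases:

#include <cassert>
#include <cmath>

int main() {
  const double Xs[] = {0.5, -3.25, 1e-9, -7.0};
  const double Zs[] = {2.0, -0.125, 4e6};
  for (double X : Xs)
    for (double Y : Xs)
      for (double Z : Zs) {
        // fma(-x, -y, z) == fma(x, y, z): the two sign flips cancel exactly.
        assert(std::fma(-X, -Y, Z) == std::fma(X, Y, Z));
        // fma(|x|, |x|, z) == fma(x, x, z): |x|*|x| and x*x are the same value.
        assert(std::fma(std::fabs(X), std::fabs(X), Z) == std::fma(X, X, Z));
        // fma(x, 1.0, z) == x + z: x*1.0 is exact, so only one rounding remains.
        assert(std::fma(X, 1.0, Z) == X + Z);
      }
  return 0;
}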
- if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, AC, DT) >= - 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, &AC, + &DT) >= 16) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); @@ -1514,8 +1684,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } case Intrinsic::ppc_qpx_qvlfs: // Turn PPC QPX qvlfs -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, AC, DT) >= - 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, &AC, + &DT) >= 16) { Type *VTy = VectorType::get(Builder->getFloatTy(), II->getType()->getVectorNumElements()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), @@ -1526,8 +1696,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::ppc_qpx_qvlfd: // Turn PPC QPX qvlfd -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 32, DL, II, AC, DT) >= - 32) { + if (getOrEnforceKnownAlignment(II->getArgOperand(0), 32, DL, II, &AC, + &DT) >= 32) { Value *Ptr = Builder->CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); return new LoadInst(Ptr); @@ -1535,8 +1705,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::ppc_qpx_qvstfs: // Turn PPC QPX qvstfs -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, AC, DT) >= - 16) { + if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, &AC, + &DT) >= 16) { Type *VTy = VectorType::get(Builder->getFloatTy(), II->getArgOperand(0)->getType()->getVectorNumElements()); Value *TOp = Builder->CreateFPTrunc(II->getArgOperand(0), VTy); @@ -1547,8 +1717,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::ppc_qpx_qvstfd: // Turn PPC QPX qvstfd -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 32, DL, II, AC, DT) >= - 32) { + if (getOrEnforceKnownAlignment(II->getArgOperand(1), 32, DL, II, &AC, + &DT) >= 32) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder->CreateBitCast(II->getArgOperand(1), OpPtrTy); @@ -1607,7 +1777,23 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_cvtsd2si: case Intrinsic::x86_sse2_cvtsd2si64: case Intrinsic::x86_sse2_cvttsd2si: - case Intrinsic::x86_sse2_cvttsd2si64: { + case Intrinsic::x86_sse2_cvttsd2si64: + case Intrinsic::x86_avx512_vcvtss2si32: + case Intrinsic::x86_avx512_vcvtss2si64: + case Intrinsic::x86_avx512_vcvtss2usi32: + case Intrinsic::x86_avx512_vcvtss2usi64: + case Intrinsic::x86_avx512_vcvtsd2si32: + case Intrinsic::x86_avx512_vcvtsd2si64: + case Intrinsic::x86_avx512_vcvtsd2usi32: + case Intrinsic::x86_avx512_vcvtsd2usi64: + case Intrinsic::x86_avx512_cvttss2si: + case Intrinsic::x86_avx512_cvttss2si64: + case Intrinsic::x86_avx512_cvttss2usi: + case Intrinsic::x86_avx512_cvttss2usi64: + case Intrinsic::x86_avx512_cvttsd2si: + case Intrinsic::x86_avx512_cvttsd2si64: + case Intrinsic::x86_avx512_cvttsd2usi: + case Intrinsic::x86_avx512_cvttsd2usi64: { // These intrinsics only demand the 0th element of their input vectors. If // we can simplify the input based on that, do so now. 
Value *Arg = II->getArgOperand(0); @@ -1654,7 +1840,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_ucomigt_sd: case Intrinsic::x86_sse2_ucomile_sd: case Intrinsic::x86_sse2_ucomilt_sd: - case Intrinsic::x86_sse2_ucomineq_sd: { + case Intrinsic::x86_sse2_ucomineq_sd: + case Intrinsic::x86_avx512_vcomi_ss: + case Intrinsic::x86_avx512_vcomi_sd: + case Intrinsic::x86_avx512_mask_cmp_ss: + case Intrinsic::x86_avx512_mask_cmp_sd: { // These intrinsics only demand the 0th element of their input vectors. If // we can simplify the input based on that, do so now. bool MadeChange = false; @@ -1674,50 +1864,155 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::x86_sse_add_ss: - case Intrinsic::x86_sse_sub_ss: - case Intrinsic::x86_sse_mul_ss: - case Intrinsic::x86_sse_div_ss: + case Intrinsic::x86_avx512_mask_add_ps_512: + case Intrinsic::x86_avx512_mask_div_ps_512: + case Intrinsic::x86_avx512_mask_mul_ps_512: + case Intrinsic::x86_avx512_mask_sub_ps_512: + case Intrinsic::x86_avx512_mask_add_pd_512: + case Intrinsic::x86_avx512_mask_div_pd_512: + case Intrinsic::x86_avx512_mask_mul_pd_512: + case Intrinsic::x86_avx512_mask_sub_pd_512: + // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular + // IR operations. + if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) { + if (R->getValue() == 4) { + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + + Value *V; + switch (II->getIntrinsicID()) { + default: llvm_unreachable("Case stmts out of sync!"); + case Intrinsic::x86_avx512_mask_add_ps_512: + case Intrinsic::x86_avx512_mask_add_pd_512: + V = Builder->CreateFAdd(Arg0, Arg1); + break; + case Intrinsic::x86_avx512_mask_sub_ps_512: + case Intrinsic::x86_avx512_mask_sub_pd_512: + V = Builder->CreateFSub(Arg0, Arg1); + break; + case Intrinsic::x86_avx512_mask_mul_ps_512: + case Intrinsic::x86_avx512_mask_mul_pd_512: + V = Builder->CreateFMul(Arg0, Arg1); + break; + case Intrinsic::x86_avx512_mask_div_ps_512: + case Intrinsic::x86_avx512_mask_div_pd_512: + V = Builder->CreateFDiv(Arg0, Arg1); + break; + } + + // Create a select for the masking. + V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), + *Builder); + return replaceInstUsesWith(*II, V); + } + } + break; + + case Intrinsic::x86_avx512_mask_add_ss_round: + case Intrinsic::x86_avx512_mask_div_ss_round: + case Intrinsic::x86_avx512_mask_mul_ss_round: + case Intrinsic::x86_avx512_mask_sub_ss_round: + case Intrinsic::x86_avx512_mask_add_sd_round: + case Intrinsic::x86_avx512_mask_div_sd_round: + case Intrinsic::x86_avx512_mask_mul_sd_round: + case Intrinsic::x86_avx512_mask_sub_sd_round: + // If the rounding mode is CUR_DIRECTION(4) we can turn these into regular + // IR operations. + if (auto *R = dyn_cast<ConstantInt>(II->getArgOperand(4))) { + if (R->getValue() == 4) { + // Extract the element as scalars. 
+ Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + Value *LHS = Builder->CreateExtractElement(Arg0, (uint64_t)0); + Value *RHS = Builder->CreateExtractElement(Arg1, (uint64_t)0); + + Value *V; + switch (II->getIntrinsicID()) { + default: llvm_unreachable("Case stmts out of sync!"); + case Intrinsic::x86_avx512_mask_add_ss_round: + case Intrinsic::x86_avx512_mask_add_sd_round: + V = Builder->CreateFAdd(LHS, RHS); + break; + case Intrinsic::x86_avx512_mask_sub_ss_round: + case Intrinsic::x86_avx512_mask_sub_sd_round: + V = Builder->CreateFSub(LHS, RHS); + break; + case Intrinsic::x86_avx512_mask_mul_ss_round: + case Intrinsic::x86_avx512_mask_mul_sd_round: + V = Builder->CreateFMul(LHS, RHS); + break; + case Intrinsic::x86_avx512_mask_div_ss_round: + case Intrinsic::x86_avx512_mask_div_sd_round: + V = Builder->CreateFDiv(LHS, RHS); + break; + } + + // Handle the masking aspect of the intrinsic. + Value *Mask = II->getArgOperand(3); + auto *C = dyn_cast<ConstantInt>(Mask); + // We don't need a select if we know the mask bit is a 1. + if (!C || !C->getValue()[0]) { + // Cast the mask to an i1 vector and then extract the lowest element. + auto *MaskTy = VectorType::get(Builder->getInt1Ty(), + cast<IntegerType>(Mask->getType())->getBitWidth()); + Mask = Builder->CreateBitCast(Mask, MaskTy); + Mask = Builder->CreateExtractElement(Mask, (uint64_t)0); + // Extract the lowest element from the passthru operand. + Value *Passthru = Builder->CreateExtractElement(II->getArgOperand(2), + (uint64_t)0); + V = Builder->CreateSelect(Mask, V, Passthru); + } + + // Insert the result back into the original argument 0. + V = Builder->CreateInsertElement(Arg0, V, (uint64_t)0); + + return replaceInstUsesWith(*II, V); + } + } + LLVM_FALLTHROUGH; + + // X86 scalar intrinsics simplified with SimplifyDemandedVectorElts. + case Intrinsic::x86_avx512_mask_max_ss_round: + case Intrinsic::x86_avx512_mask_min_ss_round: + case Intrinsic::x86_avx512_mask_max_sd_round: + case Intrinsic::x86_avx512_mask_min_sd_round: + case Intrinsic::x86_avx512_mask_vfmadd_ss: + case Intrinsic::x86_avx512_mask_vfmadd_sd: + case Intrinsic::x86_avx512_maskz_vfmadd_ss: + case Intrinsic::x86_avx512_maskz_vfmadd_sd: + case Intrinsic::x86_avx512_mask3_vfmadd_ss: + case Intrinsic::x86_avx512_mask3_vfmadd_sd: + case Intrinsic::x86_avx512_mask3_vfmsub_ss: + case Intrinsic::x86_avx512_mask3_vfmsub_sd: + case Intrinsic::x86_avx512_mask3_vfnmsub_ss: + case Intrinsic::x86_avx512_mask3_vfnmsub_sd: + case Intrinsic::x86_fma_vfmadd_ss: + case Intrinsic::x86_fma_vfmsub_ss: + case Intrinsic::x86_fma_vfnmadd_ss: + case Intrinsic::x86_fma_vfnmsub_ss: + case Intrinsic::x86_fma_vfmadd_sd: + case Intrinsic::x86_fma_vfmsub_sd: + case Intrinsic::x86_fma_vfnmadd_sd: + case Intrinsic::x86_fma_vfnmsub_sd: + case Intrinsic::x86_sse_cmp_ss: case Intrinsic::x86_sse_min_ss: case Intrinsic::x86_sse_max_ss: - case Intrinsic::x86_sse_cmp_ss: - case Intrinsic::x86_sse2_add_sd: - case Intrinsic::x86_sse2_sub_sd: - case Intrinsic::x86_sse2_mul_sd: - case Intrinsic::x86_sse2_div_sd: + case Intrinsic::x86_sse2_cmp_sd: case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: - case Intrinsic::x86_sse2_cmp_sd: { - // These intrinsics only demand the lowest element of the second input - // vector. 
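The AVX-512 lowering above only fires when the rounding operand is CUR_DIRECTION(4); the arithmetic then becomes ordinary IR and the write-mask becomes a select against the passthru operand. A rough scalar model of that masking semantics (a sketch of the intent, not the intrinsic definition):

#include <cassert>
#include <cstdint>

int main() {
  const unsigned VWidth = 4;
  double A[VWidth] = {1, 2, 3, 4}, B[VWidth] = {10, 20, 30, 40};
  double Passthru[VWidth] = {-1, -1, -1, -1}, R[VWidth];
  uint8_t Mask = 0x5;                                  // elements 0 and 2 enabled

  // result[i] = mask bit i ? op(a[i], b[i]) : passthru[i]
  for (unsigned I = 0; I < VWidth; ++I)
    R[I] = ((Mask >> I) & 1) ? A[I] + B[I] : Passthru[I];

  assert(R[0] == 11 && R[1] == -1 && R[2] == 33 && R[3] == -1);
  return 0;
}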
- Value *Arg1 = II->getArgOperand(1); - unsigned VWidth = Arg1->getType()->getVectorNumElements(); - if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { - II->setArgOperand(1, V); - return II; - } - break; - } - case Intrinsic::x86_sse41_round_ss: - case Intrinsic::x86_sse41_round_sd: { - // These intrinsics demand the upper elements of the first input vector and - // the lowest element of the second input vector. - bool MadeChange = false; - Value *Arg0 = II->getArgOperand(0); - Value *Arg1 = II->getArgOperand(1); - unsigned VWidth = Arg0->getType()->getVectorNumElements(); - if (Value *V = SimplifyDemandedVectorEltsHigh(Arg0, VWidth, VWidth - 1)) { - II->setArgOperand(0, V); - MadeChange = true; - } - if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { - II->setArgOperand(1, V); - MadeChange = true; - } - if (MadeChange) - return II; - break; + case Intrinsic::x86_sse41_round_sd: + case Intrinsic::x86_xop_vfrcz_ss: + case Intrinsic::x86_xop_vfrcz_sd: { + unsigned VWidth = II->getType()->getVectorNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { + if (V != II) + return replaceInstUsesWith(*II, V); + return II; + } + break; } // Constant fold ashr( <A x Bi>, Ci ). @@ -1727,18 +2022,29 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_psrai_w: case Intrinsic::x86_avx2_psrai_d: case Intrinsic::x86_avx2_psrai_w: + case Intrinsic::x86_avx512_psrai_q_128: + case Intrinsic::x86_avx512_psrai_q_256: + case Intrinsic::x86_avx512_psrai_d_512: + case Intrinsic::x86_avx512_psrai_q_512: + case Intrinsic::x86_avx512_psrai_w_512: case Intrinsic::x86_sse2_psrli_d: case Intrinsic::x86_sse2_psrli_q: case Intrinsic::x86_sse2_psrli_w: case Intrinsic::x86_avx2_psrli_d: case Intrinsic::x86_avx2_psrli_q: case Intrinsic::x86_avx2_psrli_w: + case Intrinsic::x86_avx512_psrli_d_512: + case Intrinsic::x86_avx512_psrli_q_512: + case Intrinsic::x86_avx512_psrli_w_512: case Intrinsic::x86_sse2_pslli_d: case Intrinsic::x86_sse2_pslli_q: case Intrinsic::x86_sse2_pslli_w: case Intrinsic::x86_avx2_pslli_d: case Intrinsic::x86_avx2_pslli_q: case Intrinsic::x86_avx2_pslli_w: + case Intrinsic::x86_avx512_pslli_d_512: + case Intrinsic::x86_avx512_pslli_q_512: + case Intrinsic::x86_avx512_pslli_w_512: if (Value *V = simplifyX86immShift(*II, *Builder)) return replaceInstUsesWith(*II, V); break; @@ -1747,18 +2053,29 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_psra_w: case Intrinsic::x86_avx2_psra_d: case Intrinsic::x86_avx2_psra_w: + case Intrinsic::x86_avx512_psra_q_128: + case Intrinsic::x86_avx512_psra_q_256: + case Intrinsic::x86_avx512_psra_d_512: + case Intrinsic::x86_avx512_psra_q_512: + case Intrinsic::x86_avx512_psra_w_512: case Intrinsic::x86_sse2_psrl_d: case Intrinsic::x86_sse2_psrl_q: case Intrinsic::x86_sse2_psrl_w: case Intrinsic::x86_avx2_psrl_d: case Intrinsic::x86_avx2_psrl_q: case Intrinsic::x86_avx2_psrl_w: + case Intrinsic::x86_avx512_psrl_d_512: + case Intrinsic::x86_avx512_psrl_q_512: + case Intrinsic::x86_avx512_psrl_w_512: case Intrinsic::x86_sse2_psll_d: case Intrinsic::x86_sse2_psll_q: case Intrinsic::x86_sse2_psll_w: case Intrinsic::x86_avx2_psll_d: case Intrinsic::x86_avx2_psll_q: - case Intrinsic::x86_avx2_psll_w: { + case Intrinsic::x86_avx2_psll_w: + case Intrinsic::x86_avx512_psll_d_512: + case Intrinsic::x86_avx512_psll_q_512: + case Intrinsic::x86_avx512_psll_w_512: { 
if (Value *V = simplifyX86immShift(*II, *Builder)) return replaceInstUsesWith(*II, V); @@ -1780,16 +2097,50 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_psllv_d_256: case Intrinsic::x86_avx2_psllv_q: case Intrinsic::x86_avx2_psllv_q_256: + case Intrinsic::x86_avx512_psllv_d_512: + case Intrinsic::x86_avx512_psllv_q_512: + case Intrinsic::x86_avx512_psllv_w_128: + case Intrinsic::x86_avx512_psllv_w_256: + case Intrinsic::x86_avx512_psllv_w_512: case Intrinsic::x86_avx2_psrav_d: case Intrinsic::x86_avx2_psrav_d_256: + case Intrinsic::x86_avx512_psrav_q_128: + case Intrinsic::x86_avx512_psrav_q_256: + case Intrinsic::x86_avx512_psrav_d_512: + case Intrinsic::x86_avx512_psrav_q_512: + case Intrinsic::x86_avx512_psrav_w_128: + case Intrinsic::x86_avx512_psrav_w_256: + case Intrinsic::x86_avx512_psrav_w_512: case Intrinsic::x86_avx2_psrlv_d: case Intrinsic::x86_avx2_psrlv_d_256: case Intrinsic::x86_avx2_psrlv_q: case Intrinsic::x86_avx2_psrlv_q_256: + case Intrinsic::x86_avx512_psrlv_d_512: + case Intrinsic::x86_avx512_psrlv_q_512: + case Intrinsic::x86_avx512_psrlv_w_128: + case Intrinsic::x86_avx512_psrlv_w_256: + case Intrinsic::x86_avx512_psrlv_w_512: if (Value *V = simplifyX86varShift(*II, *Builder)) return replaceInstUsesWith(*II, V); break; + case Intrinsic::x86_sse2_pmulu_dq: + case Intrinsic::x86_sse41_pmuldq: + case Intrinsic::x86_avx2_pmul_dq: + case Intrinsic::x86_avx2_pmulu_dq: + case Intrinsic::x86_avx512_pmul_dq_512: + case Intrinsic::x86_avx512_pmulu_dq_512: { + unsigned VWidth = II->getType()->getVectorNumElements(); + APInt UndefElts(VWidth, 0); + APInt DemandedElts = APInt::getAllOnesValue(VWidth); + if (Value *V = SimplifyDemandedVectorElts(II, DemandedElts, UndefElts)) { + if (V != II) + return replaceInstUsesWith(*II, V); + return II; + } + break; + } + case Intrinsic::x86_sse41_insertps: if (Value *V = simplifyX86insertps(*II, *Builder)) return replaceInstUsesWith(*II, V); @@ -1807,10 +2158,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // See if we're dealing with constant values. Constant *C1 = dyn_cast<Constant>(Op1); ConstantInt *CILength = - C1 ? dyn_cast<ConstantInt>(C1->getAggregateElement((unsigned)0)) + C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)0)) : nullptr; ConstantInt *CIIndex = - C1 ? dyn_cast<ConstantInt>(C1->getAggregateElement((unsigned)1)) + C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)1)) : nullptr; // Attempt to simplify to a constant, shuffle vector or EXTRQI call. @@ -1870,7 +2221,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // See if we're dealing with constant values. Constant *C1 = dyn_cast<Constant>(Op1); ConstantInt *CI11 = - C1 ? dyn_cast<ConstantInt>(C1->getAggregateElement((unsigned)1)) + C1 ? dyn_cast_or_null<ConstantInt>(C1->getAggregateElement((unsigned)1)) : nullptr; // Attempt to simplify to a constant, shuffle vector or INSERTQI call. 
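The pmuldq/pmuludq handling added above leans on SimplifyDemandedVectorElts because each 64-bit product only reads the even-numbered 32-bit elements of its sources; the odd elements are dead. A scalar model of that fact (pmuldq flavor, illustrative values):

#include <cassert>
#include <cstdint>

int main() {
  int32_t A[4] = {7, 999, -3, 999};   // elements 1 and 3 are never read
  int32_t B[4] = {-5, 123, 11, 123};
  int64_t R[2];
  for (unsigned I = 0; I < 2; ++I)
    R[I] = (int64_t)A[2 * I] * (int64_t)B[2 * I];   // sign-extending multiply
  assert(R[0] == -35 && R[1] == -33);
  return 0;
}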
@@ -1964,14 +2315,17 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_ssse3_pshuf_b_128: case Intrinsic::x86_avx2_pshuf_b: + case Intrinsic::x86_avx512_pshuf_b_512: if (Value *V = simplifyX86pshufb(*II, *Builder)) return replaceInstUsesWith(*II, V); break; case Intrinsic::x86_avx_vpermilvar_ps: case Intrinsic::x86_avx_vpermilvar_ps_256: + case Intrinsic::x86_avx512_vpermilvar_ps_512: case Intrinsic::x86_avx_vpermilvar_pd: case Intrinsic::x86_avx_vpermilvar_pd_256: + case Intrinsic::x86_avx512_vpermilvar_pd_512: if (Value *V = simplifyX86vpermilvar(*II, *Builder)) return replaceInstUsesWith(*II, V); break; @@ -1982,6 +2336,28 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); break; + case Intrinsic::x86_avx512_mask_permvar_df_256: + case Intrinsic::x86_avx512_mask_permvar_df_512: + case Intrinsic::x86_avx512_mask_permvar_di_256: + case Intrinsic::x86_avx512_mask_permvar_di_512: + case Intrinsic::x86_avx512_mask_permvar_hi_128: + case Intrinsic::x86_avx512_mask_permvar_hi_256: + case Intrinsic::x86_avx512_mask_permvar_hi_512: + case Intrinsic::x86_avx512_mask_permvar_qi_128: + case Intrinsic::x86_avx512_mask_permvar_qi_256: + case Intrinsic::x86_avx512_mask_permvar_qi_512: + case Intrinsic::x86_avx512_mask_permvar_sf_256: + case Intrinsic::x86_avx512_mask_permvar_sf_512: + case Intrinsic::x86_avx512_mask_permvar_si_256: + case Intrinsic::x86_avx512_mask_permvar_si_512: + if (Value *V = simplifyX86vpermv(*II, *Builder)) { + // We simplified the permuting, now create a select for the masking. + V = emitX86MaskSelect(II->getArgOperand(3), V, II->getArgOperand(2), + *Builder); + return replaceInstUsesWith(*II, V); + } + break; + case Intrinsic::x86_avx_vperm2f128_pd_256: case Intrinsic::x86_avx_vperm2f128_ps_256: case Intrinsic::x86_avx_vperm2f128_si_256: @@ -2104,7 +2480,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::arm_neon_vst2lane: case Intrinsic::arm_neon_vst3lane: case Intrinsic::arm_neon_vst4lane: { - unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), DL, II, AC, DT); + unsigned MemAlign = + getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT); unsigned AlignArg = II->getNumArgOperands() - 1; ConstantInt *IntrAlign = dyn_cast<ConstantInt>(II->getArgOperand(AlignArg)); if (IntrAlign && IntrAlign->getZExtValue() < MemAlign) { @@ -2194,6 +2571,85 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::amdgcn_class: { + enum { + S_NAN = 1 << 0, // Signaling NaN + Q_NAN = 1 << 1, // Quiet NaN + N_INFINITY = 1 << 2, // Negative infinity + N_NORMAL = 1 << 3, // Negative normal + N_SUBNORMAL = 1 << 4, // Negative subnormal + N_ZERO = 1 << 5, // Negative zero + P_ZERO = 1 << 6, // Positive zero + P_SUBNORMAL = 1 << 7, // Positive subnormal + P_NORMAL = 1 << 8, // Positive normal + P_INFINITY = 1 << 9 // Positive infinity + }; + + const uint32_t FullMask = S_NAN | Q_NAN | N_INFINITY | N_NORMAL | + N_SUBNORMAL | N_ZERO | P_ZERO | P_SUBNORMAL | P_NORMAL | P_INFINITY; + + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + const ConstantInt *CMask = dyn_cast<ConstantInt>(Src1); + if (!CMask) { + if (isa<UndefValue>(Src0)) + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); + + if (isa<UndefValue>(Src1)) + return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), false)); + break; + } + + uint32_t Mask = CMask->getZExtValue(); + + // If all tests are made, it doesn't matter what the value is. 
+ if ((Mask & FullMask) == FullMask) + return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), true)); + + if ((Mask & FullMask) == 0) + return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), false)); + + if (Mask == (S_NAN | Q_NAN)) { + // Equivalent of isnan. Replace with standard fcmp. + Value *FCmp = Builder->CreateFCmpUNO(Src0, Src0); + FCmp->takeName(II); + return replaceInstUsesWith(*II, FCmp); + } + + const ConstantFP *CVal = dyn_cast<ConstantFP>(Src0); + if (!CVal) { + if (isa<UndefValue>(Src0)) + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); + + // Clamp mask to used bits + if ((Mask & FullMask) != Mask) { + CallInst *NewCall = Builder->CreateCall(II->getCalledFunction(), + { Src0, ConstantInt::get(Src1->getType(), Mask & FullMask) } + ); + + NewCall->takeName(II); + return replaceInstUsesWith(*II, NewCall); + } + + break; + } + + const APFloat &Val = CVal->getValueAPF(); + + bool Result = + ((Mask & S_NAN) && Val.isNaN() && Val.isSignaling()) || + ((Mask & Q_NAN) && Val.isNaN() && !Val.isSignaling()) || + ((Mask & N_INFINITY) && Val.isInfinity() && Val.isNegative()) || + ((Mask & N_NORMAL) && Val.isNormal() && Val.isNegative()) || + ((Mask & N_SUBNORMAL) && Val.isDenormal() && Val.isNegative()) || + ((Mask & N_ZERO) && Val.isZero() && Val.isNegative()) || + ((Mask & P_ZERO) && Val.isZero() && !Val.isNegative()) || + ((Mask & P_SUBNORMAL) && Val.isDenormal() && !Val.isNegative()) || + ((Mask & P_NORMAL) && Val.isNormal() && !Val.isNegative()) || + ((Mask & P_INFINITY) && Val.isInfinity() && !Val.isNegative()); + + return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result)); + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. @@ -2243,6 +2699,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } case Intrinsic::lifetime_start: + // Asan needs to poison memory to detect invalid access which is possible + // even for empty lifetime range. + if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress)) + break; + if (removeTriviallyEmptyRange(*II, Intrinsic::lifetime_start, Intrinsic::lifetime_end, *this)) return nullptr; @@ -2274,24 +2735,20 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // assume( (load addr) != null ) -> add 'nonnull' metadata to load // (if assume is valid at the load) - if (ICmpInst* ICmp = dyn_cast<ICmpInst>(IIOperand)) { - Value *LHS = ICmp->getOperand(0); - Value *RHS = ICmp->getOperand(1); - if (ICmpInst::ICMP_NE == ICmp->getPredicate() && - isa<LoadInst>(LHS) && - isa<Constant>(RHS) && - RHS->getType()->isPointerTy() && - cast<Constant>(RHS)->isNullValue()) { - LoadInst* LI = cast<LoadInst>(LHS); - if (isValidAssumeForContext(II, LI, DT)) { - MDNode *MD = MDNode::get(II->getContext(), None); - LI->setMetadata(LLVMContext::MD_nonnull, MD); - return eraseInstFromFunction(*II); - } - } + CmpInst::Predicate Pred; + Instruction *LHS; + if (match(IIOperand, m_ICmp(Pred, m_Instruction(LHS), m_Zero())) && + Pred == ICmpInst::ICMP_NE && LHS->getOpcode() == Instruction::Load && + LHS->getType()->isPointerTy() && + isValidAssumeForContext(II, LHS, &DT)) { + MDNode *MD = MDNode::get(II->getContext(), None); + LHS->setMetadata(LLVMContext::MD_nonnull, MD); + return eraseInstFromFunction(*II); + // TODO: apply nonnull return attributes to calls and invokes // TODO: apply range metadata for range check patterns? 
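The amdgcn.class constant folding above maps each mask bit to one IEEE class of the first operand. A rough host-side model of that per-class test (not the GCN definition; <cmath> classification cannot distinguish signaling from quiet NaNs, so both NaN bits are treated together here):

#include <cassert>
#include <cmath>
#include <cstdint>

static bool classTest(double Val, uint32_t Mask) {
  enum {
    S_NAN = 1 << 0, Q_NAN = 1 << 1, N_INFINITY = 1 << 2, N_NORMAL = 1 << 3,
    N_SUBNORMAL = 1 << 4, N_ZERO = 1 << 5, P_ZERO = 1 << 6,
    P_SUBNORMAL = 1 << 7, P_NORMAL = 1 << 8, P_INFINITY = 1 << 9
  };
  bool Neg = std::signbit(Val);
  switch (std::fpclassify(Val)) {
  case FP_NAN:       return Mask & (S_NAN | Q_NAN);
  case FP_INFINITE:  return Mask & (Neg ? N_INFINITY : P_INFINITY);
  case FP_NORMAL:    return Mask & (Neg ? N_NORMAL : P_NORMAL);
  case FP_SUBNORMAL: return Mask & (Neg ? N_SUBNORMAL : P_SUBNORMAL);
  default:           return Mask & (Neg ? N_ZERO : P_ZERO);   // FP_ZERO
  }
}

int main() {
  assert(classTest(-0.0, 1 << 5));              // N_ZERO bit matches -0.0
  assert(!classTest(1.5, 1 << 2));              // 1.5 is not a negative infinity
  assert(classTest(NAN, (1 << 0) | (1 << 1)));  // the "isnan" mask
  return 0;
}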
} + // If there is a dominating assume with the same condition as this one, // then this one is redundant, and should be removed. APInt KnownZero(1, 0), KnownOne(1, 0); @@ -2299,6 +2756,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (KnownOne.isAllOnesValue()) return eraseInstFromFunction(*II); + // Update the cache of affected values for this assumption (we might be + // here because we just simplified the condition). + AC.updateAffectedValues(II); break; } case Intrinsic::experimental_gc_relocate: { @@ -2329,7 +2789,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, ConstantPointerNull::get(PT)); // isKnownNonNull -> nonnull attribute - if (isKnownNonNullAt(DerivedPtr, II, DT)) + if (isKnownNonNullAt(DerivedPtr, II, &DT)) II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull); } @@ -2389,7 +2849,7 @@ Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { auto InstCombineRAUW = [this](Instruction *From, Value *With) { replaceInstUsesWith(*From, With); }; - LibCallSimplifier Simplifier(DL, TLI, InstCombineRAUW); + LibCallSimplifier Simplifier(DL, &TLI, InstCombineRAUW); if (Value *With = Simplifier.optimizeCall(CI)) { ++NumSimplified; return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With); @@ -2477,8 +2937,7 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) { /// Improvements for call and invoke instructions. Instruction *InstCombiner::visitCallSite(CallSite CS) { - - if (isAllocLikeFn(CS.getInstruction(), TLI)) + if (isAllocLikeFn(CS.getInstruction(), &TLI)) return visitAllocSite(*CS.getInstruction()); bool Changed = false; @@ -2492,7 +2951,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { for (Value *V : CS.args()) { if (V->getType()->isPointerTy() && !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && - isKnownNonNullAt(V, CS.getInstruction(), DT)) + isKnownNonNullAt(V, CS.getInstruction(), &DT)) Indices.push_back(ArgNo + 1); ArgNo++; } @@ -2613,14 +3072,14 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { /// If the callee is a constexpr cast of a function, attempt to move the cast to /// the arguments of the call/invoke. bool InstCombiner::transformConstExprCastCall(CallSite CS) { - Function *Callee = - dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts()); + auto *Callee = dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts()); if (!Callee) return false; - // The prototype of thunks are a lie, don't try to directly call such - // functions. + + // The prototype of a thunk is a lie. Don't directly call such a function. if (Callee->hasFnAttribute("thunk")) return false; + Instruction *Caller = CS.getInstruction(); const AttributeSet &CallerPAL = CS.getAttributes(); @@ -2842,8 +3301,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { CallInst *CI = cast<CallInst>(Caller); NC = Builder->CreateCall(Callee, Args, OpBundles); NC->takeName(CI); - if (CI->isTailCall()) - cast<CallInst>(NC)->setTailCall(); + cast<CallInst>(NC)->setTailCallKind(CI->getTailCallKind()); cast<CallInst>(NC)->setCallingConv(CI->getCallingConv()); cast<CallInst>(NC)->setAttributes(NewCallerPAL); } @@ -2966,7 +3424,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, ++Idx; ++I; - } while (1); + } while (true); } // Add any function attributes. @@ -3001,7 +3459,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, ++Idx; ++I; - } while (1); + } while (true); } // Replace the trampoline call with a direct call. 
Let the generic @@ -3027,10 +3485,10 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, cast<InvokeInst>(NewCaller)->setAttributes(NewPAL); } else { NewCaller = CallInst::Create(NewCallee, NewArgs, OpBundles); - if (cast<CallInst>(Caller)->isTailCall()) - cast<CallInst>(NewCaller)->setTailCall(); - cast<CallInst>(NewCaller)-> - setCallingConv(cast<CallInst>(Caller)->getCallingConv()); + cast<CallInst>(NewCaller)->setTailCallKind( + cast<CallInst>(Caller)->getTailCallKind()); + cast<CallInst>(NewCaller)->setCallingConv( + cast<CallInst>(Caller)->getCallingConv()); cast<CallInst>(NewCaller)->setAttributes(NewPAL); } diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 2055615..e74b590 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/PatternMatch.h" @@ -161,8 +162,8 @@ Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty, if (Constant *C = dyn_cast<Constant>(V)) { C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); // If we got a constantexpr back, try to simplify it with DL info. - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) - C = ConstantFoldConstantExpression(CE, DL, TLI); + if (Constant *FoldedC = ConstantFoldConstant(C, DL, &TLI)) + C = FoldedC; return C; } @@ -227,20 +228,14 @@ Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty, return InsertNewInstWith(Res, *I); } +Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1, + const CastInst *CI2) { + Type *SrcTy = CI1->getSrcTy(); + Type *MidTy = CI1->getDestTy(); + Type *DstTy = CI2->getDestTy(); -/// This function is a wrapper around CastInst::isEliminableCastPair. It -/// simply extracts arguments and returns what that function returns. -static Instruction::CastOps -isEliminableCastPair(const CastInst *CI, ///< First cast instruction - unsigned opcode, ///< Opcode for the second cast - Type *DstTy, ///< Target type for the second cast - const DataLayout &DL) { - Type *SrcTy = CI->getOperand(0)->getType(); // A from above - Type *MidTy = CI->getType(); // B from above - - // Get the opcodes of the two Cast instructions - Instruction::CastOps firstOp = Instruction::CastOps(CI->getOpcode()); - Instruction::CastOps secondOp = Instruction::CastOps(opcode); + Instruction::CastOps firstOp = Instruction::CastOps(CI1->getOpcode()); + Instruction::CastOps secondOp = Instruction::CastOps(CI2->getOpcode()); Type *SrcIntPtrTy = SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr; Type *MidIntPtrTy = @@ -260,54 +255,28 @@ isEliminableCastPair(const CastInst *CI, ///< First cast instruction return Instruction::CastOps(Res); } -/// Return true if the cast from "V to Ty" actually results in any code being -/// generated and is interesting to optimize out. -/// If the cast can be eliminated by some other simple transformation, we prefer -/// to do the simplification first. -bool InstCombiner::ShouldOptimizeCast(Instruction::CastOps opc, const Value *V, - Type *Ty) { - // Noop casts and casts of constants should be eliminated trivially. 
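isEliminableCastPair, now taking the two CastInsts directly, still answers the same question: can an A->B->C cast chain be replaced by a single A->C cast? One concrete instance of that (a hypothetical standalone check, not from the patch):

// zext i8 -> i32 followed by trunc i32 -> i16 is equivalent to a single
// zext i8 -> i16, checked over every i8 value.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned V = 0; V < 256; ++V) {
    uint8_t X = (uint8_t)V;
    uint16_t TwoCasts = (uint16_t)(uint32_t)X;   // A -> B -> C
    uint16_t OneCast = (uint16_t)X;              // A -> C
    assert(TwoCasts == OneCast);
  }
  return 0;
}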
- if (V->getType() == Ty || isa<Constant>(V)) return false; - - // If this is another cast that can be eliminated, we prefer to have it - // eliminated. - if (const CastInst *CI = dyn_cast<CastInst>(V)) - if (isEliminableCastPair(CI, opc, Ty, DL)) - return false; - - // If this is a vector sext from a compare, then we don't want to break the - // idiom where each element of the extended vector is either zero or all ones. - if (opc == Instruction::SExt && isa<CmpInst>(V) && Ty->isVectorTy()) - return false; - - return true; -} - - /// @brief Implement the transforms common to all CastInst visitors. Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { Value *Src = CI.getOperand(0); - // Many cases of "cast of a cast" are eliminable. If it's eliminable we just - // eliminate it now. - if (CastInst *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast - if (Instruction::CastOps opc = - isEliminableCastPair(CSrc, CI.getOpcode(), CI.getType(), DL)) { + // Try to eliminate a cast of a cast. + if (auto *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast + if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) { // The first cast (CSrc) is eliminable so we need to fix up or replace // the second cast (CI). CSrc will then have a good chance of being dead. - return CastInst::Create(opc, CSrc->getOperand(0), CI.getType()); + return CastInst::Create(NewOpc, CSrc->getOperand(0), CI.getType()); } } - // If we are casting a select then fold the cast into the select - if (SelectInst *SI = dyn_cast<SelectInst>(Src)) + // If we are casting a select, then fold the cast into the select. + if (auto *SI = dyn_cast<SelectInst>(Src)) if (Instruction *NV = FoldOpIntoSelect(CI, SI)) return NV; - // If we are casting a PHI then fold the cast into the PHI + // If we are casting a PHI, then fold the cast into the PHI. if (isa<PHINode>(Src)) { - // We don't do this if this would create a PHI node with an illegal type if - // it is currently legal. + // Don't do this if it would create a PHI node with an illegal type from a + // legal type. if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() || ShouldChangeType(CI.getType(), Src->getType())) if (Instruction *NV = FoldOpIntoPhi(CI)) @@ -474,19 +443,39 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC, return ExtractElementInst::Create(VecInput, IC.Builder->getInt32(Elt)); } +/// Try to narrow the width of bitwise logic instructions with constants. +Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) { + Type *SrcTy = Trunc.getSrcTy(); + Type *DestTy = Trunc.getType(); + if (isa<IntegerType>(SrcTy) && !ShouldChangeType(SrcTy, DestTy)) + return nullptr; + + BinaryOperator *LogicOp; + Constant *C; + if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(LogicOp))) || + !LogicOp->isBitwiseLogicOp() || + !match(LogicOp->getOperand(1), m_Constant(C))) + return nullptr; + + // trunc (logic X, C) --> logic (trunc X, C') + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *NarrowOp0 = Builder->CreateTrunc(LogicOp->getOperand(0), DestTy); + return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC); +} + Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (Instruction *Result = commonCastTransforms(CI)) return Result; // Test if the trunc is the user of a select which is part of a // minimum or maximum operation. If so, don't do any more simplification. 
- // Even simplifying demanded bits can break the canonical form of a + // Even simplifying demanded bits can break the canonical form of a // min/max. Value *LHS, *RHS; if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0))) if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN) return nullptr; - + // See if we can simplify any instructions used by the input whose sole // purpose is to compute bits we don't care about. if (SimplifyDemandedInstructionBits(CI)) @@ -562,14 +551,26 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { } } - // Transform "trunc (and X, cst)" -> "and (trunc X), cst" so long as the dest - // type isn't non-native. + if (Instruction *I = shrinkBitwiseLogic(CI)) + return I; + if (Src->hasOneUse() && isa<IntegerType>(SrcTy) && - ShouldChangeType(SrcTy, DestTy) && - match(Src, m_And(m_Value(A), m_ConstantInt(Cst)))) { - Value *NewTrunc = Builder->CreateTrunc(A, DestTy, A->getName() + ".tr"); - return BinaryOperator::CreateAnd(NewTrunc, - ConstantExpr::getTrunc(Cst, DestTy)); + ShouldChangeType(SrcTy, DestTy)) { + // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the + // dest type is native and cst < dest size. + if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) && + !match(A, m_Shr(m_Value(), m_Constant()))) { + // Skip shifts of shift by constants. It undoes a combine in + // FoldShiftByConstant and is the extend in reg pattern. + const unsigned DestSize = DestTy->getScalarSizeInBits(); + if (Cst->getValue().ult(DestSize)) { + Value *NewTrunc = Builder->CreateTrunc(A, DestTy, A->getName() + ".tr"); + + return BinaryOperator::Create( + Instruction::Shl, NewTrunc, + ConstantInt::get(DestTy, Cst->getValue().trunc(DestSize))); + } + } } if (Instruction *I = foldVecTruncToExtElt(CI, *this, DL)) @@ -578,10 +579,8 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { return nullptr; } -/// Transform (zext icmp) to bitwise / integer operations in order to eliminate -/// the icmp. -Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, - bool DoXform) { +Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, + bool DoTransform) { // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. @@ -592,7 +591,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) || (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) { - if (!DoXform) return ICI; + if (!DoTransform) return ICI; Value *In = ICI->getOperand(0); Value *Sh = ConstantInt::get(In->getType(), @@ -627,7 +626,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, APInt KnownZeroMask(~KnownZero); if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? 
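The visitTrunc changes above narrow bitwise logic (shrinkBitwiseLogic) and left shifts whose amount fits the destination width; both rest on truncation distributing over those operations. A quick exhaustive sketch at 16-to-8 bits (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  const uint16_t C = 0x3A5C;
  for (uint32_t X = 0; X <= 0xFFFF; ++X) {
    // trunc (and X, C) == and (trunc X), (trunc C)
    assert(uint8_t(X & C) == uint8_t(uint8_t(X) & uint8_t(C)));
    // trunc (shl X, Sh) == shl (trunc X), Sh, provided Sh < destination width
    for (unsigned Sh = 0; Sh < 8; ++Sh)
      assert(uint8_t(X << Sh) == uint8_t(uint8_t(X) << Sh));
  }
  return 0;
}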
- if (!DoXform) return ICI; + if (!DoTransform) return ICI; bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; if (Op1CV != 0 && (Op1CV != KnownZeroMask)) { @@ -655,7 +654,9 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, if (CI.getType() == In->getType()) return replaceInstUsesWith(CI, In); - return CastInst::CreateIntegerCast(In, CI.getType(), false/*ZExt*/); + + Value *IntCast = Builder->CreateIntCast(In, CI.getType(), false); + return replaceInstUsesWith(CI, IntCast); } } } @@ -678,7 +679,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI, APInt KnownBits = KnownZeroLHS | KnownOneLHS; APInt UnknownBit = ~KnownBits; if (UnknownBit.countPopulation() == 1) { - if (!DoXform) return ICI; + if (!DoTransform) return ICI; Value *Result = Builder->CreateXor(LHS, RHS); @@ -760,9 +761,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, // If the operation is an AND/OR/XOR and the bits to clear are zero in the // other side, BitsToClear is ok. - if (Tmp == 0 && - (Opc == Instruction::And || Opc == Instruction::Or || - Opc == Instruction::Xor)) { + if (Tmp == 0 && I->isBitwiseLogicOp()) { // We use MaskedValueIsZero here for generality, but the case we care // about the most is constant RHS. unsigned VSize = V->getType()->getScalarSizeInBits(); @@ -922,16 +921,26 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src); if (SrcI && SrcI->getOpcode() == Instruction::Or) { - // zext (or icmp, icmp) --> or (zext icmp), (zext icmp) if at least one - // of the (zext icmp) will be transformed. + // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) if at least one + // of the (zext icmp) can be eliminated. If so, immediately perform the + // according elimination. ICmpInst *LHS = dyn_cast<ICmpInst>(SrcI->getOperand(0)); ICmpInst *RHS = dyn_cast<ICmpInst>(SrcI->getOperand(1)); if (LHS && RHS && LHS->hasOneUse() && RHS->hasOneUse() && (transformZExtICmp(LHS, CI, false) || transformZExtICmp(RHS, CI, false))) { + // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) Value *LCast = Builder->CreateZExt(LHS, CI.getType(), LHS->getName()); Value *RCast = Builder->CreateZExt(RHS, CI.getType(), RHS->getName()); - return BinaryOperator::Create(Instruction::Or, LCast, RCast); + BinaryOperator *Or = BinaryOperator::Create(Instruction::Or, LCast, RCast); + + // Perform the elimination. + if (auto *LZExt = dyn_cast<ZExtInst>(LCast)) + transformZExtICmp(LHS, *LZExt); + if (auto *RZExt = dyn_cast<ZExtInst>(RCast)) + transformZExtICmp(RHS, *RZExt); + + return Or; } } @@ -952,14 +961,6 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { return BinaryOperator::CreateXor(Builder->CreateAnd(X, ZC), ZC); } - // zext (xor i1 X, true) to i32 --> xor (zext i1 X to i32), 1 - if (SrcI && SrcI->hasOneUse() && - SrcI->getType()->getScalarType()->isIntegerTy(1) && - match(SrcI, m_Not(m_Value(X))) && (!X->hasOneUse() || !isa<CmpInst>(X))) { - Value *New = Builder->CreateZExt(X, CI.getType()); - return BinaryOperator::CreateXor(New, ConstantInt::get(CI.getType(), 1)); - } - return nullptr; } @@ -1132,7 +1133,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { Type *SrcTy = Src->getType(), *DestTy = CI.getType(); // If we know that the value being extended is positive, we can use a zext - // instead. + // instead. 
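The visitSExt change noted above replaces a sign extension with a zero extension when the sign bit of the source is known clear; the two agree exactly in that case. A tiny illustrative check:

#include <cassert>
#include <cstdint>

int main() {
  for (int V = 0; V <= 127; ++V) {       // every i8 value with a clear sign bit
    int8_t X = (int8_t)V;
    assert((int32_t)X == (int32_t)(uint8_t)X);   // sext == zext here
  }
  return 0;
}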
bool KnownZero, KnownOne; ComputeSignBit(Src, KnownZero, KnownOne, 0, &CI); if (KnownZero) { @@ -1238,14 +1239,14 @@ static Value *lookThroughFPExtensions(Value *V) { if (CFP->getType() == Type::getPPC_FP128Ty(V->getContext())) return V; // No constant folding of this. // See if the value can be truncated to half and then reextended. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEhalf)) + if (Value *V = fitsInFPType(CFP, APFloat::IEEEhalf())) return V; // See if the value can be truncated to float and then reextended. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEsingle)) + if (Value *V = fitsInFPType(CFP, APFloat::IEEEsingle())) return V; if (CFP->getType()->isDoubleTy()) return V; // Won't shrink. - if (Value *V = fitsInFPType(CFP, APFloat::IEEEdouble)) + if (Value *V = fitsInFPType(CFP, APFloat::IEEEdouble())) return V; // Don't try to shrink to various long double types. } @@ -1789,6 +1790,205 @@ static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand()); } +/// Change the type of a bitwise logic operation if we can eliminate a bitcast. +static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, + InstCombiner::BuilderTy &Builder) { + Type *DestTy = BitCast.getType(); + BinaryOperator *BO; + if (!DestTy->getScalarType()->isIntegerTy() || + !match(BitCast.getOperand(0), m_OneUse(m_BinOp(BO))) || + !BO->isBitwiseLogicOp()) + return nullptr; + + // FIXME: This transform is restricted to vector types to avoid backend + // problems caused by creating potentially illegal operations. If a fix-up is + // added to handle that situation, we can remove this check. + if (!DestTy->isVectorTy() || !BO->getType()->isVectorTy()) + return nullptr; + + Value *X; + if (match(BO->getOperand(0), m_OneUse(m_BitCast(m_Value(X)))) && + X->getType() == DestTy && !isa<Constant>(X)) { + // bitcast(logic(bitcast(X), Y)) --> logic'(X, bitcast(Y)) + Value *CastedOp1 = Builder.CreateBitCast(BO->getOperand(1), DestTy); + return BinaryOperator::Create(BO->getOpcode(), X, CastedOp1); + } + + if (match(BO->getOperand(1), m_OneUse(m_BitCast(m_Value(X)))) && + X->getType() == DestTy && !isa<Constant>(X)) { + // bitcast(logic(Y, bitcast(X))) --> logic'(bitcast(Y), X) + Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy); + return BinaryOperator::Create(BO->getOpcode(), CastedOp0, X); + } + + return nullptr; +} + +/// Change the type of a select if we can eliminate a bitcast. +static Instruction *foldBitCastSelect(BitCastInst &BitCast, + InstCombiner::BuilderTy &Builder) { + Value *Cond, *TVal, *FVal; + if (!match(BitCast.getOperand(0), + m_OneUse(m_Select(m_Value(Cond), m_Value(TVal), m_Value(FVal))))) + return nullptr; + + // A vector select must maintain the same number of elements in its operands. + Type *CondTy = Cond->getType(); + Type *DestTy = BitCast.getType(); + if (CondTy->isVectorTy()) { + if (!DestTy->isVectorTy()) + return nullptr; + if (DestTy->getVectorNumElements() != CondTy->getVectorNumElements()) + return nullptr; + } + + // FIXME: This transform is restricted from changing the select between + // scalars and vectors to avoid backend problems caused by creating + // potentially illegal operations. If a fix-up is added to handle that + // situation, we can remove this check. 
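fitsInFPType, used by lookThroughFPExtensions above, asks whether a floating-point constant can be represented exactly in a narrower type; conceptually that is a lossless round-trip test. A small sketch of the idea in plain C++ (this stands in for the APFloat-based check and is not the patch's code):

#include <cassert>

static bool fitsInFloat(double D) {
  float F = static_cast<float>(D);    // round to the narrower type
  return static_cast<double>(F) == D; // lossless iff the round-trip is exact
}

int main() {
  assert(fitsInFloat(0.5));      // exactly representable as a float
  assert(fitsInFloat(123456.0)); // integers below 2^24 are exact in float
  assert(!fitsInFloat(0.1));     // 0.1 rounds differently in float and double
  return 0;
}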
+ if (DestTy->isVectorTy() != TVal->getType()->isVectorTy()) + return nullptr; + + auto *Sel = cast<Instruction>(BitCast.getOperand(0)); + Value *X; + if (match(TVal, m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy && + !isa<Constant>(X)) { + // bitcast(select(Cond, bitcast(X), Y)) --> select'(Cond, X, bitcast(Y)) + Value *CastedVal = Builder.CreateBitCast(FVal, DestTy); + return SelectInst::Create(Cond, X, CastedVal, "", nullptr, Sel); + } + + if (match(FVal, m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestTy && + !isa<Constant>(X)) { + // bitcast(select(Cond, Y, bitcast(X))) --> select'(Cond, bitcast(Y), X) + Value *CastedVal = Builder.CreateBitCast(TVal, DestTy); + return SelectInst::Create(Cond, CastedVal, X, "", nullptr, Sel); + } + + return nullptr; +} + +/// Check if all users of CI are StoreInsts. +static bool hasStoreUsersOnly(CastInst &CI) { + for (User *U : CI.users()) { + if (!isa<StoreInst>(U)) + return false; + } + return true; +} + +/// This function handles following case +/// +/// A -> B cast +/// PHI +/// B -> A cast +/// +/// All the related PHI nodes can be replaced by new PHI nodes with type A. +/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. +Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { + // BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp. + if (hasStoreUsersOnly(CI)) + return nullptr; + + Value *Src = CI.getOperand(0); + Type *SrcTy = Src->getType(); // Type B + Type *DestTy = CI.getType(); // Type A + + SmallVector<PHINode *, 4> PhiWorklist; + SmallSetVector<PHINode *, 4> OldPhiNodes; + + // Find all of the A->B casts and PHI nodes. + // We need to inpect all related PHI nodes, but PHIs can be cyclic, so + // OldPhiNodes is used to track all known PHI nodes, before adding a new + // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first. + PhiWorklist.push_back(PN); + OldPhiNodes.insert(PN); + while (!PhiWorklist.empty()) { + auto *OldPN = PhiWorklist.pop_back_val(); + for (Value *IncValue : OldPN->incoming_values()) { + if (isa<Constant>(IncValue)) + continue; + + if (auto *LI = dyn_cast<LoadInst>(IncValue)) { + // If there is a sequence of one or more load instructions, each loaded + // value is used as address of later load instruction, bitcast is + // necessary to change the value type, don't optimize it. For + // simplicity we give up if the load address comes from another load. + Value *Addr = LI->getOperand(0); + if (Addr == &CI || isa<LoadInst>(Addr)) + return nullptr; + if (LI->hasOneUse() && LI->isSimple()) + continue; + // If a LoadInst has more than one use, changing the type of loaded + // value may create another bitcast. + return nullptr; + } + + if (auto *PNode = dyn_cast<PHINode>(IncValue)) { + if (OldPhiNodes.insert(PNode)) + PhiWorklist.push_back(PNode); + continue; + } + + auto *BCI = dyn_cast<BitCastInst>(IncValue); + // We can't handle other instructions. + if (!BCI) + return nullptr; + + // Verify it's a A->B cast. + Type *TyA = BCI->getOperand(0)->getType(); + Type *TyB = BCI->getType(); + if (TyA != DestTy || TyB != SrcTy) + return nullptr; + } + } + + // For each old PHI node, create a corresponding new PHI node with a type A. + SmallDenseMap<PHINode *, PHINode *> NewPNodes; + for (auto *OldPN : OldPhiNodes) { + Builder->SetInsertPoint(OldPN); + PHINode *NewPN = Builder->CreatePHI(DestTy, OldPN->getNumOperands()); + NewPNodes[OldPN] = NewPN; + } + + // Fill in the operands of new PHI nodes. 
+ for (auto *OldPN : OldPhiNodes) { + PHINode *NewPN = NewPNodes[OldPN]; + for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { + Value *V = OldPN->getOperand(j); + Value *NewV = nullptr; + if (auto *C = dyn_cast<Constant>(V)) { + NewV = ConstantExpr::getBitCast(C, DestTy); + } else if (auto *LI = dyn_cast<LoadInst>(V)) { + Builder->SetInsertPoint(LI->getNextNode()); + NewV = Builder->CreateBitCast(LI, DestTy); + Worklist.Add(LI); + } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { + NewV = BCI->getOperand(0); + } else if (auto *PrevPN = dyn_cast<PHINode>(V)) { + NewV = NewPNodes[PrevPN]; + } + assert(NewV); + NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j)); + } + } + + // If there is a store with type B, change it to type A. + for (User *U : PN->users()) { + auto *SI = dyn_cast<StoreInst>(U); + if (SI && SI->isSimple() && SI->getOperand(0) == PN) { + Builder->SetInsertPoint(SI); + auto *NewBC = + cast<BitCastInst>(Builder->CreateBitCast(NewPNodes[PN], SrcTy)); + SI->setOperand(0, NewBC); + Worklist.Add(SI); + assert(hasStoreUsersOnly(*NewBC)); + } + } + + return replaceInstUsesWith(CI, NewPNodes[PN]); +} + Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // If the operands are integer typed then apply the integer transforms, // otherwise just apply the common ones. @@ -1912,9 +2112,20 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { } } + // Handle the A->B->A cast, and there is an intervening PHI node. + if (PHINode *PN = dyn_cast<PHINode>(Src)) + if (Instruction *I = optimizeBitCastFromPhi(CI, PN)) + return I; + if (Instruction *I = canonicalizeBitCastExtElt(CI, *this, DL)) return I; + if (Instruction *I = foldBitCastBitwiseLogic(CI, *Builder)) + return I; + + if (Instruction *I = foldBitCastSelect(CI, *Builder)) + return I; + if (SrcTy->isPointerTy()) return commonPointerCastTransforms(CI); return commonCastTransforms(CI); diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 961497f..428f94b 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -35,17 +35,12 @@ using namespace PatternMatch; // How many times is a select replaced by one of its operands? STATISTIC(NumSel, "Number of select opts"); -// Initialization Routines -static ConstantInt *getOne(Constant *C) { - return ConstantInt::get(cast<IntegerType>(C->getType()), 1); -} - -static ConstantInt *ExtractElement(Constant *V, Constant *Idx) { +static ConstantInt *extractElement(Constant *V, Constant *Idx) { return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx)); } -static bool HasAddOverflow(ConstantInt *Result, +static bool hasAddOverflow(ConstantInt *Result, ConstantInt *In1, ConstantInt *In2, bool IsSigned) { if (!IsSigned) @@ -58,28 +53,28 @@ static bool HasAddOverflow(ConstantInt *Result, /// Compute Result = In1+In2, returning true if the result overflowed for this /// type. 
-static bool AddWithOverflow(Constant *&Result, Constant *In1, +static bool addWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getAdd(In1, In2); if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (HasAddOverflow(ExtractElement(Result, Idx), - ExtractElement(In1, Idx), - ExtractElement(In2, Idx), + if (hasAddOverflow(extractElement(Result, Idx), + extractElement(In1, Idx), + extractElement(In2, Idx), IsSigned)) return true; } return false; } - return HasAddOverflow(cast<ConstantInt>(Result), + return hasAddOverflow(cast<ConstantInt>(Result), cast<ConstantInt>(In1), cast<ConstantInt>(In2), IsSigned); } -static bool HasSubOverflow(ConstantInt *Result, +static bool hasSubOverflow(ConstantInt *Result, ConstantInt *In1, ConstantInt *In2, bool IsSigned) { if (!IsSigned) @@ -93,23 +88,23 @@ static bool HasSubOverflow(ConstantInt *Result, /// Compute Result = In1-In2, returning true if the result overflowed for this /// type. -static bool SubWithOverflow(Constant *&Result, Constant *In1, +static bool subWithOverflow(Constant *&Result, Constant *In1, Constant *In2, bool IsSigned = false) { Result = ConstantExpr::getSub(In1, In2); if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (HasSubOverflow(ExtractElement(Result, Idx), - ExtractElement(In1, Idx), - ExtractElement(In2, Idx), + if (hasSubOverflow(extractElement(Result, Idx), + extractElement(In1, Idx), + extractElement(In2, Idx), IsSigned)) return true; } return false; } - return HasSubOverflow(cast<ConstantInt>(Result), + return hasSubOverflow(cast<ConstantInt>(Result), cast<ConstantInt>(In1), cast<ConstantInt>(In2), IsSigned); } @@ -126,26 +121,26 @@ static bool isBranchOnSignBitCheck(ICmpInst &I, bool isSignBit) { /// Given an exploded icmp instruction, return true if the comparison only /// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the /// result of the comparison is true when the input value is signed. -static bool isSignBitCheck(ICmpInst::Predicate Pred, ConstantInt *RHS, +static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS, bool &TrueIfSigned) { switch (Pred) { case ICmpInst::ICMP_SLT: // True if LHS s< 0 TrueIfSigned = true; - return RHS->isZero(); + return RHS == 0; case ICmpInst::ICMP_SLE: // True if LHS s<= RHS and RHS == -1 TrueIfSigned = true; - return RHS->isAllOnesValue(); + return RHS.isAllOnesValue(); case ICmpInst::ICMP_SGT: // True if LHS s> -1 TrueIfSigned = false; - return RHS->isAllOnesValue(); + return RHS.isAllOnesValue(); case ICmpInst::ICMP_UGT: // True if LHS u> RHS and RHS == high-bit-mask - 1 TrueIfSigned = true; - return RHS->isMaxValue(true); + return RHS.isMaxSignedValue(); case ICmpInst::ICMP_UGE: // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc) TrueIfSigned = true; - return RHS->getValue().isSignBit(); + return RHS.isSignBit(); default: return false; } @@ -154,19 +149,20 @@ static bool isSignBitCheck(ICmpInst::Predicate Pred, ConstantInt *RHS, /// Returns true if the exploded icmp can be expressed as a signed comparison /// to zero and updates the predicate accordingly. /// The signedness of the comparison is preserved. 
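The predicate/constant pairs that isSignBitCheck now recognizes on a raw APInt are all equivalent ways of testing the sign bit. Spot-checking those equivalences on i32 (an illustrative check of the identities, not the InstCombine code):

#include <cassert>
#include <cstdint>

int main() {
  const int32_t Samples[] = {0, 1, -1, 42, -42, INT32_MAX, INT32_MIN};
  for (int32_t S : Samples) {
    uint32_t U = static_cast<uint32_t>(S); // same bits, viewed unsigned
    bool SignSet = S < 0;                  // X s< 0: the canonical sign-bit test
    assert((S <= -1) == SignSet);          // X s<= -1
    assert((U > 0x7FFFFFFFu) == SignSet);  // X u> max-signed
    assert((U >= 0x80000000u) == SignSet); // X u>= sign-bit mask
    assert((S > -1) == !SignSet);          // X s> -1: sign bit clear
  }
  return 0;
}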
-static bool isSignTest(ICmpInst::Predicate &Pred, const ConstantInt *RHS) { +/// TODO: Refactor with decomposeBitTestICmp()? +static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { if (!ICmpInst::isSigned(Pred)) return false; - if (RHS->isZero()) + if (C == 0) return ICmpInst::isRelational(Pred); - if (RHS->isOne()) { + if (C == 1) { if (Pred == ICmpInst::ICMP_SLT) { Pred = ICmpInst::ICMP_SLE; return true; } - } else if (RHS->isAllOnesValue()) { + } else if (C.isAllOnesValue()) { if (Pred == ICmpInst::ICMP_SGT) { Pred = ICmpInst::ICMP_SGE; return true; @@ -176,16 +172,10 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const ConstantInt *RHS) { return false; } -/// Return true if the constant is of the form 1+0+. This is the same as -/// lowones(~X). -static bool isHighOnes(const ConstantInt *CI) { - return (~CI->getValue() + 1).isPowerOf2(); -} - /// Given a signed integer type and a set of known zero and one bits, compute /// the maximum and minimum values that could have the specified known zero and /// known one bits, returning them in Min/Max. -static void ComputeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, +static void computeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, const APInt &KnownOne, APInt &Min, APInt &Max) { assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && @@ -208,7 +198,7 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const APInt &KnownZero, /// Given an unsigned integer type and a set of known zero and one bits, compute /// the maximum and minimum values that could have the specified known zero and /// known one bits, returning them in Min/Max. -static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, +static void computeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, const APInt &KnownOne, APInt &Min, APInt &Max) { assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() && @@ -231,9 +221,10 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero, /// /// If AndCst is non-null, then the loaded value is masked with that constant /// before doing the comparison. This handles cases like "A[i]&4 == 0". -Instruction *InstCombiner:: -FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, - CmpInst &ICI, ConstantInt *AndCst) { +Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, + GlobalVariable *GV, + CmpInst &ICI, + ConstantInt *AndCst) { Constant *Init = GV->getInitializer(); if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init)) return nullptr; @@ -319,7 +310,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, // Find out if the comparison would be true or false for the i'th element. Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt, - CompareRHS, DL, TLI); + CompareRHS, DL, &TLI); // If the result is undef for this element, ignore it. if (isa<UndefValue>(C)) { // Extend range state machines to cover this element in case there is an @@ -509,7 +500,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV, /// /// If we can't emit an optimized form for this expression, this returns null. 
/// -static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, +static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, const DataLayout &DL) { gep_type_iterator GTI = gep_type_begin(GEP); @@ -526,7 +517,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, if (CI->isZero()) continue; // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = dyn_cast<StructType>(*GTI)) { + if (StructType *STy = GTI.getStructTypeOrNull()) { Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); } else { uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); @@ -556,7 +547,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, if (CI->isZero()) continue; // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = dyn_cast<StructType>(*GTI)) { + if (StructType *STy = GTI.getStructTypeOrNull()) { Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); } else { uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); @@ -893,6 +884,10 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, if (!GEPLHS->hasAllConstantIndices()) return nullptr; + // Make sure the pointers have the same type. + if (GEPLHS->getType() != RHS->getType()) + return nullptr; + Value *PtrBase, *Index; std::tie(PtrBase, Index) = getAsConstantIndexedAddress(GEPLHS, DL); @@ -919,7 +914,7 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, /// Fold comparisons between a GEP instruction and something else. At this point /// we know that the GEP is on the LHS of the comparison. -Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, +Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I) { // Don't transform signed compares of GEPs into index compares. Even if the @@ -941,7 +936,7 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // This transformation (ignoring the base and scales) is valid because we // know pointers can't overflow since the gep is inbounds. See if we can // output an optimized form. - Value *Offset = EvaluateGEPOffsetExpression(GEPLHS, *this, DL); + Value *Offset = evaluateGEPOffsetExpression(GEPLHS, *this, DL); // If not, synthesize the offset the hard way. if (!Offset) @@ -1003,12 +998,12 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // If one of the GEPs has all zero indices, recurse. if (GEPLHS->hasAllZeroIndices()) - return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0), + return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0), ICmpInst::getSwappedPredicate(Cond), I); // If the other GEP has all zero indices, recurse. 
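The offset accumulation in evaluateGEPOffsetExpression mirrors ordinary address arithmetic: a struct index adds the field's constant offset, an array index adds index times element size. A tiny sketch with a hypothetical struct layout (the type and field names are invented for illustration):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical layout, used only to illustrate the accumulation rule.
struct Example {
  int32_t Header;
  int64_t Payload[4];
};

int main() {
  Example E{};
  const char *Base = reinterpret_cast<const char *>(&E);
  for (std::size_t I = 0; I < 4; ++I) {
    // Struct index: add the field's constant offset.
    // Array index: add index * element size.
    std::size_t Offset = offsetof(Example, Payload) + I * sizeof(int64_t);
    assert(reinterpret_cast<const char *>(&E.Payload[I]) == Base + Offset);
  }
  return 0;
}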
if (GEPRHS->hasAllZeroIndices()) - return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); + return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands()) { @@ -1056,8 +1051,9 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return transformToIndexedCompare(GEPLHS, RHS, Cond, DL); } -Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, - Value *Other) { +Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, + const AllocaInst *Alloca, + const Value *Other) { assert(ICI.isEquality() && "Cannot fold non-equality comparison."); // It would be tempting to fold away comparisons between allocas and any @@ -1076,8 +1072,8 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, unsigned MaxIter = 32; // Break cycles and bound to constant-time. - SmallVector<Use *, 32> Worklist; - for (Use &U : Alloca->uses()) { + SmallVector<const Use *, 32> Worklist; + for (const Use &U : Alloca->uses()) { if (Worklist.size() >= MaxIter) return nullptr; Worklist.push_back(&U); @@ -1086,8 +1082,8 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, unsigned NumCmps = 0; while (!Worklist.empty()) { assert(Worklist.size() <= MaxIter); - Use *U = Worklist.pop_back_val(); - Value *V = U->getUser(); + const Use *U = Worklist.pop_back_val(); + const Value *V = U->getUser(); --MaxIter; if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) || @@ -1096,7 +1092,7 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } else if (isa<LoadInst>(V)) { // Loading from the pointer doesn't escape it. continue; - } else if (auto *SI = dyn_cast<StoreInst>(V)) { + } else if (const auto *SI = dyn_cast<StoreInst>(V)) { // Storing *to* the pointer is fine, but storing the pointer escapes it. if (SI->getValueOperand() == U->get()) return nullptr; @@ -1105,7 +1101,7 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, if (NumCmps++) return nullptr; // Found more than one cmp. continue; - } else if (auto *Intrin = dyn_cast<IntrinsicInst>(V)) { + } else if (const auto *Intrin = dyn_cast<IntrinsicInst>(V)) { switch (Intrin->getIntrinsicID()) { // These intrinsics don't escape or compare the pointer. Memset is safe // because we don't allow ptrtoint. Memcpy and memmove are safe because @@ -1120,7 +1116,7 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } else { return nullptr; } - for (Use &U : V->uses()) { + for (const Use &U : V->uses()) { if (Worklist.size() >= MaxIter) return nullptr; Worklist.push_back(&U); @@ -1134,9 +1130,9 @@ Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, } /// Fold "icmp pred (X+CI), X". -Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, - Value *X, ConstantInt *CI, - ICmpInst::Predicate Pred) { +Instruction *InstCombiner::foldICmpAddOpConst(Instruction &ICI, + Value *X, ConstantInt *CI, + ICmpInst::Predicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. @@ -1181,52 +1177,995 @@ Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI, return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C)); } -/// Fold "icmp pred, ([su]div X, DivRHS), CmpRHS" where DivRHS and CmpRHS are -/// both known to be integer constants. 
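The "icmp pred (X+C), X" folds work because, once C is known non-zero, comparing X+C against X is really a question about wrapping. For the unsigned flavor the underlying identity is easy to brute-force on i8 (this checks the identity the fold relies on, not the code above):

#include <cassert>

int main() {
  for (int X = 0; X <= 255; ++X) {
    for (int C = 1; C <= 255; ++C) {
      bool AddIsBelowX = ((X + C) & 0xFF) < X; // i8 value of X + C, compared u< X
      assert(AddIsBelowX == (X > 255 - C));    // true exactly when the add wrapped
    }
  }
  return 0;
}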
-Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS) { - ConstantInt *CmpRHS = cast<ConstantInt>(ICI.getOperand(1)); - const APInt &CmpRHSV = CmpRHS->getValue(); +/// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" -> +/// (icmp eq/ne A, Log2(AP2/AP1)) -> +/// (icmp eq/ne A, Log2(AP2) - Log2(AP1)). +Instruction *InstCombiner::foldICmpShrConstConst(ICmpInst &I, Value *A, + const APInt &AP1, + const APInt &AP2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + bool IsAShr = isa<AShrOperator>(I.getOperand(0)); + if (IsAShr) { + if (AP2.isAllOnesValue()) + return nullptr; + if (AP2.isNegative() != AP1.isNegative()) + return nullptr; + if (AP2.sgt(AP1)) + return nullptr; + } + + if (!AP1) + // 'A' must be large enough to shift out the highest set bit. + return getICmp(I.ICMP_UGT, A, + ConstantInt::get(A->getType(), AP2.logBase2())); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + int Shift; + if (IsAShr && AP1.isNegative()) + Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes(); + else + Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros(); + + if (Shift > 0) { + if (IsAShr && AP1 == AP2.ashr(Shift)) { + // There are multiple solutions if we are comparing against -1 and the LHS + // of the ashr is not a power of two. + if (AP1.isAllOnesValue() && !AP2.isPowerOf2()) + return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift)); + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } else if (AP1 == AP2.lshr(Shift)) { + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + } + } + + // Shifting const2 will never be equal to const1. + // FIXME: This should always be handled by InstSimplify? + auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE); + return replaceInstUsesWith(I, TorF); +} + +/// Handle "(icmp eq/ne (shl AP2, A), AP1)" -> +/// (icmp eq/ne A, TrailingZeros(AP1) - TrailingZeros(AP2)). +Instruction *InstCombiner::foldICmpShlConstConst(ICmpInst &I, Value *A, + const APInt &AP1, + const APInt &AP2) { + assert(I.isEquality() && "Cannot fold icmp gt/lt"); + + auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { + if (I.getPredicate() == I.ICMP_NE) + Pred = CmpInst::getInversePredicate(Pred); + return new ICmpInst(Pred, LHS, RHS); + }; + + // Don't bother doing any work for cases which InstSimplify handles. + if (AP2 == 0) + return nullptr; + + unsigned AP2TrailingZeros = AP2.countTrailingZeros(); + + if (!AP1 && AP2TrailingZeros != 0) + return getICmp( + I.ICMP_UGE, A, + ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + + if (AP1 == AP2) + return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + + // Get the distance between the lowest bits that are set. + int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + + if (Shift > 0 && AP2.shl(Shift) == AP1) + return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + + // Shifting const2 will never be equal to const1. + // FIXME: This should always be handled by InstSimplify? 
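foldICmpShlConstConst collapses an equality compare of (shl C2, A) against C1 to a compare of A against the distance between the lowest set bits. For example, with C2 = 6 and C1 = 24 the distance is ctz(24) - ctz(6) = 2, so only A == 2 satisfies the compare; a brute-force check on i8 (illustrative values, not from the patch):

#include <cassert>
#include <cstdint>

int main() {
  const unsigned C2 = 6, C1 = 24;
  const unsigned Shift = 2; // ctz(C1) - ctz(C2) = 3 - 1
  for (unsigned A = 0; A < 8; ++A) { // only in-range shift amounts for i8
    uint8_t Shl = static_cast<uint8_t>(C2 << A);
    assert((Shl == C1) == (A == Shift));
  }
  return 0;
}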
+ auto *TorF = ConstantInt::get(I.getType(), I.getPredicate() == I.ICMP_NE); + return replaceInstUsesWith(I, TorF); +} + +/// The caller has matched a pattern of the form: +/// I = icmp ugt (add (add A, B), CI2), CI1 +/// If this is of the form: +/// sum = a + b +/// if (sum+128 >u 255) +/// Then replace it with llvm.sadd.with.overflow.i8. +/// +static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, + ConstantInt *CI2, ConstantInt *CI1, + InstCombiner &IC) { + // The transformation we're trying to do here is to transform this into an + // llvm.sadd.with.overflow. To do this, we have to replace the original add + // with a narrower add, and discard the add-with-constant that is part of the + // range check (if we can't eliminate it, this isn't profitable). + + // In order to eliminate the add-with-constant, the compare can be its only + // use. + Instruction *AddWithCst = cast<Instruction>(I.getOperand(0)); + if (!AddWithCst->hasOneUse()) + return nullptr; + + // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. + if (!CI2->getValue().isPowerOf2()) + return nullptr; + unsigned NewWidth = CI2->getValue().countTrailingZeros(); + if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) + return nullptr; + + // The width of the new add formed is 1 more than the bias. + ++NewWidth; + + // Check to see that CI1 is an all-ones value with NewWidth bits. + if (CI1->getBitWidth() == NewWidth || + CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) + return nullptr; + + // This is only really a signed overflow check if the inputs have been + // sign-extended; check for that condition. For example, if CI2 is 2^31 and + // the operands of the add are 64 bits wide, we need at least 33 sign bits. + unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; + if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || + IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) + return nullptr; + + // In order to replace the original add with a narrower + // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant + // and truncates that discard the high bits of the add. Verify that this is + // the case. + Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0)); + for (User *U : OrigAdd->users()) { + if (U == AddWithCst) + continue; + + // Only accept truncates for now. We would really like a nice recursive + // predicate like SimplifyDemandedBits, but which goes downwards the use-def + // chain to see which bits of a value are actually demanded. If the + // original add had another add which was then immediately truncated, we + // could still do the transformation. + TruncInst *TI = dyn_cast<TruncInst>(U); + if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth) + return nullptr; + } + + // If the pattern matches, truncate the inputs to the narrower type and + // use the sadd_with_overflow intrinsic to efficiently compute both the + // result and the overflow bit. + Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth); + Value *F = Intrinsic::getDeclaration(I.getModule(), + Intrinsic::sadd_with_overflow, NewType); + + InstCombiner::BuilderTy *Builder = IC.Builder; + + // Put the new code above the original add, in case there are any uses of the + // add between the add and the compare. 
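The idiom matched by processUGT_ADDCST_ADD widens two i8 values, adds them, and tests (sum + 128) u> 255; that unsigned range check fires exactly when the narrow signed addition would overflow, which is why the whole pattern can become llvm.sadd.with.overflow.i8. Exhaustively checking the equivalence over all i8 pairs (a verification of the identity, not the transform code):

#include <cassert>
#include <cstdint>

int main() {
  for (int A = -128; A <= 127; ++A) {
    for (int B = -128; B <= 127; ++B) {
      int Sum = A + B;                                           // the wide add
      bool RangeCheck = static_cast<uint32_t>(Sum + 128) > 255u; // sum+128 u> 255
      bool Overflows = Sum < -128 || Sum > 127;                  // i8 add overflow?
      assert(RangeCheck == Overflows);
    }
  }
  return 0;
}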
+ Builder->SetInsertPoint(OrigAdd); + + Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName() + ".trunc"); + Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName() + ".trunc"); + CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd"); + Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); + Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); + + // The inner add was the result of the narrow add, zero extended to the + // wider type. Replace it with the result computed by the intrinsic. + IC.replaceInstUsesWith(*OrigAdd, ZExt); + + // The original icmp gets replaced with the overflow value. + return ExtractValueInst::Create(Call, 1, "sadd.overflow"); +} + +// Fold icmp Pred X, C. +Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { + CmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Cmp.getOperand(0); + + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C))) + return nullptr; + + Value *A = nullptr, *B = nullptr; + + // Match the following pattern, which is a common idiom when writing + // overflow-safe integer arithmetic functions. The source performs an addition + // in wider type and explicitly checks for overflow using comparisons against + // INT_MIN and INT_MAX. Simplify by using the sadd_with_overflow intrinsic. + // + // TODO: This could probably be generalized to handle other overflow-safe + // operations if we worked out the formulas to compute the appropriate magic + // constants. + // + // sum = a + b + // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8 + { + ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI + if (Pred == ICmpInst::ICMP_UGT && + match(X, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) + if (Instruction *Res = processUGT_ADDCST_ADD( + Cmp, A, B, CI2, cast<ConstantInt>(Cmp.getOperand(1)), *this)) + return Res; + } + + // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) + if (*C == 0 && Pred == ICmpInst::ICMP_SGT) { + SelectPatternResult SPR = matchSelectPattern(X, A, B); + if (SPR.Flavor == SPF_SMIN) { + if (isKnownPositive(A, DL)) + return new ICmpInst(Pred, B, Cmp.getOperand(1)); + if (isKnownPositive(B, DL)) + return new ICmpInst(Pred, A, Cmp.getOperand(1)); + } + } + + // FIXME: Use m_APInt to allow folds for splat constants. + ConstantInt *CI = dyn_cast<ConstantInt>(Cmp.getOperand(1)); + if (!CI) + return nullptr; + + // Canonicalize icmp instructions based on dominating conditions. + BasicBlock *Parent = Cmp.getParent(); + BasicBlock *Dom = Parent->getSinglePredecessor(); + auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr; + ICmpInst::Predicate Pred2; + BasicBlock *TrueBB, *FalseBB; + ConstantInt *CI2; + if (BI && match(BI, m_Br(m_ICmp(Pred2, m_Specific(X), m_ConstantInt(CI2)), + TrueBB, FalseBB)) && + TrueBB != FalseBB) { + ConstantRange CR = + ConstantRange::makeAllowedICmpRegion(Pred, CI->getValue()); + ConstantRange DominatingCR = + (Parent == TrueBB) + ? ConstantRange::makeExactICmpRegion(Pred2, CI2->getValue()) + : ConstantRange::makeExactICmpRegion( + CmpInst::getInversePredicate(Pred2), CI2->getValue()); + ConstantRange Intersection = DominatingCR.intersectWith(CR); + ConstantRange Difference = DominatingCR.difference(CR); + if (Intersection.isEmptySet()) + return replaceInstUsesWith(Cmp, Builder->getFalse()); + if (Difference.isEmptySet()) + return replaceInstUsesWith(Cmp, Builder->getTrue()); + + // If this is a normal comparison, it demands all bits. If it is a sign + // bit comparison, it only demands the sign bit. 
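The smin fold above, (icmp sgt smin(PosA, B), 0) -> (icmp sgt B, 0), is sound because a known-positive operand can never be the one that drags the minimum to or below zero. A trivial spot check with made-up sample values:

#include <algorithm>
#include <cassert>

int main() {
  const int PositiveA[] = {1, 7, 1000000};
  const int AnyB[] = {-5, -1, 0, 1, 3, 999};
  for (int A : PositiveA)
    for (int B : AnyB)
      assert((std::min(A, B) > 0) == (B > 0)); // A > 0, so look through the min
  return 0;
}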
+ bool UnusedBit; + bool IsSignBit = isSignBitCheck(Pred, CI->getValue(), UnusedBit); + + // Canonicalizing a sign bit comparison that gets used in a branch, + // pessimizes codegen by generating branch on zero instruction instead + // of a test and branch. So we avoid canonicalizing in such situations + // because test and branch instruction has better branch displacement + // than compare and branch instruction. + if (!isBranchOnSignBitCheck(Cmp, IsSignBit) && !Cmp.isEquality()) { + if (auto *AI = Intersection.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder->getInt(*AI)); + if (auto *AD = Difference.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_NE, X, Builder->getInt(*AD)); + } + } + + return nullptr; +} + +/// Fold icmp (trunc X, Y), C. +Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, + Instruction *Trunc, + const APInt *C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Trunc->getOperand(0); + if (*C == 1 && C->getBitWidth() > 1) { + // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 + Value *V = nullptr; + if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V)))) + return new ICmpInst(ICmpInst::ICMP_SLT, V, + ConstantInt::get(V->getType(), 1)); + } + + if (Cmp.isEquality() && Trunc->hasOneUse()) { + // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all + // of the high bits truncated out of x are known. + unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), + SrcBits = X->getType()->getScalarSizeInBits(); + APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); + computeKnownBits(X, KnownZero, KnownOne, 0, &Cmp); + + // If all the high bits are known, we can do this xform. + if ((KnownZero | KnownOne).countLeadingOnes() >= SrcBits - DstBits) { + // Pull in the high bits from known-ones set. + APInt NewRHS = C->zext(SrcBits); + NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS)); + } + } + + return nullptr; +} + +/// Fold icmp (xor X, Y), C. +Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, + BinaryOperator *Xor, + const APInt *C) { + Value *X = Xor->getOperand(0); + Value *Y = Xor->getOperand(1); + const APInt *XorC; + if (!match(Y, m_APInt(XorC))) + return nullptr; + + // If this is a comparison that tests the signbit (X < 0) or (x > -1), + // fold the xor. + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if ((Pred == ICmpInst::ICMP_SLT && *C == 0) || + (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue())) { + + // If the sign bit of the XorCst is not set, there is no change to + // the operation, just stop using the Xor. + if (!XorC->isNegative()) { + Cmp.setOperand(0, X); + Worklist.Add(Xor); + return &Cmp; + } + + // Was the old condition true if the operand is positive? + bool isTrueIfPositive = Pred == ICmpInst::ICMP_SGT; + + // If so, the new one isn't. + isTrueIfPositive ^= true; + + Constant *CmpConstant = cast<Constant>(Cmp.getOperand(1)); + if (isTrueIfPositive) + return new ICmpInst(ICmpInst::ICMP_SGT, X, SubOne(CmpConstant)); + else + return new ICmpInst(ICmpInst::ICMP_SLT, X, AddOne(CmpConstant)); + } + + if (Xor->hasOneUse()) { + // (icmp u/s (xor X SignBit), C) -> (icmp s/u X, (xor C SignBit)) + if (!Cmp.isEquality() && XorC->isSignBit()) { + Pred = Cmp.isSigned() ? 
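The equality case of foldICmpTruncConstant widens the constant instead of truncating the value when every truncated-out bit of X is known. As an illustration, assume the high 24 bits of an i32 X are known to be 0x000012; then (trunc X to i8) == 42 is the same question as X == 0x0000122A (the assumed known bits and sample loop below are purely illustrative):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t Low = 0; Low <= 0xFF; ++Low) {
    uint32_t X = 0x00001200u | Low;                 // high 24 bits assumed known
    bool NarrowCmp = static_cast<uint8_t>(X) == 42; // icmp eq (trunc X to i8), 42
    bool WideCmp = X == 0x0000122Au;                // icmp eq X, 42 | known high bits
    assert(NarrowCmp == WideCmp);
  }
  return 0;
}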
Cmp.getUnsignedPredicate() + : Cmp.getSignedPredicate(); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + } + + // (icmp u/s (xor X ~SignBit), C) -> (icmp s/u X, (xor C ~SignBit)) + if (!Cmp.isEquality() && XorC->isMaxSignedValue()) { + Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() + : Cmp.getSignedPredicate(); + Pred = Cmp.getSwappedPredicate(Pred); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + } + } + + // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) + // iff -C is a power of 2 + if (Pred == ICmpInst::ICMP_UGT && *XorC == ~(*C) && (*C + 1).isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); + + // (icmp ult (xor X, C), -C) -> (icmp uge X, C) + // iff -C is a power of 2 + if (Pred == ICmpInst::ICMP_ULT && *XorC == -(*C) && C->isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); + + return nullptr; +} + +/// Fold icmp (and (sh X, Y), C2), C1. +Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, + const APInt *C1, const APInt *C2) { + BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0)); + if (!Shift || !Shift->isShift()) + return nullptr; + + // If this is: (X >> C3) & C2 != C1 (where any shift and any compare could + // exist), turn it into (X & (C2 << C3)) != (C1 << C3). This happens a LOT in + // code produced by the clang front-end, for bitfield access. + // This seemingly simple opportunity to fold away a shift turns out to be + // rather complicated. See PR17827 for details. + unsigned ShiftOpcode = Shift->getOpcode(); + bool IsShl = ShiftOpcode == Instruction::Shl; + const APInt *C3; + if (match(Shift->getOperand(1), m_APInt(C3))) { + bool CanFold = false; + if (ShiftOpcode == Instruction::AShr) { + // There may be some constraints that make this possible, but nothing + // simple has been discovered yet. + CanFold = false; + } else if (ShiftOpcode == Instruction::Shl) { + // For a left shift, we can fold if the comparison is not signed. We can + // also fold a signed comparison if the mask value and comparison value + // are not negative. These constraints may not be obvious, but we can + // prove that they are correct using an SMT solver. + if (!Cmp.isSigned() || (!C2->isNegative() && !C1->isNegative())) + CanFold = true; + } else if (ShiftOpcode == Instruction::LShr) { + // For a logical right shift, we can fold if the comparison is not signed. + // We can also fold a signed comparison if the shifted mask value and the + // shifted comparison value are not negative. These constraints may not be + // obvious, but we can prove that they are correct using an SMT solver. + if (!Cmp.isSigned() || + (!C2->shl(*C3).isNegative() && !C1->shl(*C3).isNegative())) + CanFold = true; + } + + if (CanFold) { + APInt NewCst = IsShl ? C1->lshr(*C3) : C1->shl(*C3); + APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3); + // Check to see if we are shifting out any of the bits being compared. + if (SameAsC1 != *C1) { + // If we shifted bits out, the fold is not going to work out. As a + // special case, check to see if this means that the result is always + // true or false now. + if (Cmp.getPredicate() == ICmpInst::ICMP_EQ) + return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType())); + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) + return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); + } else { + Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst)); + APInt NewAndCst = IsShl ? 
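The xor-with-sign-bit folds above rest on a classic trick: flipping the sign bit of both sides of a compare converts unsigned order into signed order and back. Checking that exhaustively on i8 (an identity check, not the InstCombine code):

#include <cassert>

int main() {
  for (int X = -128; X <= 127; ++X) {
    for (int C = -128; C <= 127; ++C) {
      unsigned UX = (static_cast<unsigned>(X) & 0xFFu) ^ 0x80u; // X ^ SignBit, as i8
      unsigned UC = (static_cast<unsigned>(C) & 0xFFu) ^ 0x80u; // C ^ SignBit, as i8
      assert((UX < UC) == (X < C)); // unsigned compare after xor == signed compare
    }
  }
  return 0;
}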
C2->lshr(*C3) : C2->shl(*C3); + And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst)); + And->setOperand(0, Shift->getOperand(0)); + Worklist.Add(Shift); // Shift is dead. + return &Cmp; + } + } + } + + // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is + // preferable because it allows the C2 << Y expression to be hoisted out of a + // loop if Y is invariant and X is not. + if (Shift->hasOneUse() && *C1 == 0 && Cmp.isEquality() && + !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { + // Compute C2 << Y. + Value *NewShift = + IsShl ? Builder->CreateLShr(And->getOperand(1), Shift->getOperand(1)) + : Builder->CreateShl(And->getOperand(1), Shift->getOperand(1)); + + // Compute X & (C2 << Y). + Value *NewAnd = Builder->CreateAnd(Shift->getOperand(0), NewShift); + Cmp.setOperand(0, NewAnd); + return &Cmp; + } + + return nullptr; +} + +/// Fold icmp (and X, C2), C1. +Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, + BinaryOperator *And, + const APInt *C1) { + const APInt *C2; + if (!match(And->getOperand(1), m_APInt(C2))) + return nullptr; + + if (!And->hasOneUse() || !And->getOperand(0)->hasOneUse()) + return nullptr; + + // If the LHS is an 'and' of a truncate and we can widen the and/compare to + // the input width without changing the value produced, eliminate the cast: + // + // icmp (and (trunc W), C2), C1 -> icmp (and W, C2'), C1' + // + // We can do this transformation if the constants do not have their sign bits + // set or if it is an equality comparison. Extending a relational comparison + // when we're checking the sign bit would not work. + Value *W; + if (match(And->getOperand(0), m_Trunc(m_Value(W))) && + (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) { + // TODO: Is this a good transform for vectors? Wider types may reduce + // throughput. Should this transform be limited (even for scalars) by using + // ShouldChangeType()? 
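Turning ((X u>> Y) & C2) == 0 into (X & (C2 << Y)) == 0, as done just above so the mask can be hoisted out of a loop, is an equivalence on the zero-test even though the two values themselves differ. A brute-force confirmation over all i8 inputs and shift amounts (illustrative only):

#include <cassert>

int main() {
  for (unsigned X = 0; X <= 0xFF; ++X)
    for (unsigned C2 = 0; C2 <= 0xFF; ++C2)
      for (unsigned Y = 0; Y < 8; ++Y) {
        bool Before = ((X >> Y) & C2) == 0;          // ((X u>> Y) & C2) == 0
        bool After = (X & ((C2 << Y) & 0xFFu)) == 0; // (X & (C2 << Y)) == 0, in i8
        assert(Before == After);
      }
  return 0;
}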
+ if (!Cmp.getType()->isVectorTy()) { + Type *WideType = W->getType(); + unsigned WideScalarBits = WideType->getScalarSizeInBits(); + Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits)); + Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits)); + Value *NewAnd = Builder->CreateAnd(W, ZextC2, And->getName()); + return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1); + } + } + + if (Instruction *I = foldICmpAndShift(Cmp, And, C1, C2)) + return I; + + // (icmp pred (and (or (lshr A, B), A), 1), 0) --> + // (icmp pred (and A, (or (shl 1, B), 1), 0)) + // + // iff pred isn't signed + if (!Cmp.isSigned() && *C1 == 0 && match(And->getOperand(1), m_One())) { + Constant *One = cast<Constant>(And->getOperand(1)); + Value *Or = And->getOperand(0); + Value *A, *B, *LShr; + if (match(Or, m_Or(m_Value(LShr), m_Value(A))) && + match(LShr, m_LShr(m_Specific(A), m_Value(B)))) { + unsigned UsesRemoved = 0; + if (And->hasOneUse()) + ++UsesRemoved; + if (Or->hasOneUse()) + ++UsesRemoved; + if (LShr->hasOneUse()) + ++UsesRemoved; + + // Compute A & ((1 << B) | 1) + Value *NewOr = nullptr; + if (auto *C = dyn_cast<Constant>(B)) { + if (UsesRemoved >= 1) + NewOr = ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); + } else { + if (UsesRemoved >= 3) + NewOr = Builder->CreateOr(Builder->CreateShl(One, B, LShr->getName(), + /*HasNUW=*/true), + One, Or->getName()); + } + if (NewOr) { + Value *NewAnd = Builder->CreateAnd(A, NewOr, And->getName()); + Cmp.setOperand(0, NewAnd); + return &Cmp; + } + } + } + + // (X & C2) > C1 --> (X & C2) != 0, if any bit set in (X & C2) will produce a + // result greater than C1. + unsigned NumTZ = C2->countTrailingZeros(); + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && NumTZ < C2->getBitWidth() && + APInt::getOneBitSet(C2->getBitWidth(), NumTZ).ugt(*C1)) { + Constant *Zero = Constant::getNullValue(And->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + + return nullptr; +} + +/// Fold icmp (and X, Y), C. +Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, + BinaryOperator *And, + const APInt *C) { + if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) + return I; + + // TODO: These all require that Y is constant too, so refactor with the above. + + // Try to optimize things like "A[i] & 42 == 0" to index computations. + Value *X = And->getOperand(0); + Value *Y = And->getOperand(1); + if (auto *LI = dyn_cast<LoadInst>(X)) + if (auto *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0))) + if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (GV->isConstant() && GV->hasDefinitiveInitializer() && + !LI->isVolatile() && isa<ConstantInt>(Y)) { + ConstantInt *C2 = cast<ConstantInt>(Y); + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, Cmp, C2)) + return Res; + } + + if (!Cmp.isEquality()) + return nullptr; + + // X & -C == -C -> X > u ~C + // X & -C != -C -> X <= u ~C + // iff C is a power of 2 + if (Cmp.getOperand(1) == Y && (-(*C)).isPowerOf2()) { + auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT + : CmpInst::ICMP_ULE; + return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); + } + + // (X & C2) == 0 -> (trunc X) >= 0 + // (X & C2) != 0 -> (trunc X) < 0 + // iff C2 is a power of 2 and it masks the sign bit of a legal integer type. 
+ const APInt *C2; + if (And->hasOneUse() && *C == 0 && match(Y, m_APInt(C2))) { + int32_t ExactLogBase2 = C2->exactLogBase2(); + if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { + Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); + if (And->getType()->isVectorTy()) + NTy = VectorType::get(NTy, And->getType()->getVectorNumElements()); + Value *Trunc = Builder->CreateTrunc(X, NTy); + auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE + : CmpInst::ICMP_SLT; + return new ICmpInst(NewPred, Trunc, Constant::getNullValue(NTy)); + } + } + + return nullptr; +} + +/// Fold icmp (or X, Y), C. +Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, + const APInt *C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (*C == 1) { + // icmp slt signum(V) 1 --> icmp slt V, 1 + Value *V = nullptr; + if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) + return new ICmpInst(ICmpInst::ICMP_SLT, V, + ConstantInt::get(V->getType(), 1)); + } + + if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse()) + return nullptr; + + Value *P, *Q; + if (match(Or, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { + // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 + // -> and (icmp eq P, null), (icmp eq Q, null). + Value *CmpP = + Builder->CreateICmp(Pred, P, ConstantInt::getNullValue(P->getType())); + Value *CmpQ = + Builder->CreateICmp(Pred, Q, ConstantInt::getNullValue(Q->getType())); + auto LogicOpc = Pred == ICmpInst::Predicate::ICMP_EQ ? Instruction::And + : Instruction::Or; + return BinaryOperator::Create(LogicOpc, CmpP, CmpQ); + } + + return nullptr; +} + +/// Fold icmp (mul X, Y), C. +Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, + BinaryOperator *Mul, + const APInt *C) { + const APInt *MulC; + if (!match(Mul->getOperand(1), m_APInt(MulC))) + return nullptr; + + // If this is a test of the sign bit and the multiply is sign-preserving with + // a constant operand, use the multiply LHS operand instead. + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (isSignTest(Pred, *C) && Mul->hasNoSignedWrap()) { + if (MulC->isNegative()) + Pred = ICmpInst::getSwappedPredicate(Pred); + return new ICmpInst(Pred, Mul->getOperand(0), + Constant::getNullValue(Mul->getType())); + } + + return nullptr; +} + +/// Fold icmp (shl 1, Y), C. 
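The fold implemented at the top of this hunk, (X & C2) == 0 -> (trunc X) s>= 0 when C2 is the sign bit of a narrower legal type, is just a restatement of "bit 7 clear" as "non-negative as an i8". A small check with C2 = 0x80 and arbitrary i32 samples (values chosen for illustration):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Samples[] = {0u, 0x7Fu, 0x80u, 0xFFu, 0x1234u, 0x12B4u, 0xFFFFFF80u};
  for (uint32_t X : Samples) {
    bool MaskClear = (X & 0x80u) == 0;       // (X & C2) == 0 with C2 = i8 sign bit
    int Trunc = static_cast<int>(X & 0xFFu); // value of (trunc X to i8) ...
    if (Trunc > 127)
      Trunc -= 256;                          // ... read back as a signed i8
    assert(MaskClear == (Trunc >= 0));       // same as (trunc X) s>= 0
  }
  return 0;
}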
+static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, + const APInt *C) { + Value *Y; + if (!match(Shl, m_Shl(m_One(), m_Value(Y)))) + return nullptr; + + Type *ShiftType = Shl->getType(); + uint32_t TypeBits = C->getBitWidth(); + bool CIsPowerOf2 = C->isPowerOf2(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (Cmp.isUnsigned()) { + // (1 << Y) pred C -> Y pred Log2(C) + if (!CIsPowerOf2) { + // (1 << Y) < 30 -> Y <= 4 + // (1 << Y) <= 30 -> Y <= 4 + // (1 << Y) >= 30 -> Y > 4 + // (1 << Y) > 30 -> Y > 4 + if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_ULE; + else if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_UGT; + } + + // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 + // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 + unsigned CLog2 = C->logBase2(); + if (CLog2 == TypeBits - 1) { + if (Pred == ICmpInst::ICMP_UGE) + Pred = ICmpInst::ICMP_EQ; + else if (Pred == ICmpInst::ICMP_ULT) + Pred = ICmpInst::ICMP_NE; + } + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); + } else if (Cmp.isSigned()) { + Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); + if (C->isAllOnesValue()) { + // (1 << Y) <= -1 -> Y == 31 + if (Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + + // (1 << Y) > -1 -> Y != 31 + if (Pred == ICmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); + } else if (!(*C)) { + // (1 << Y) < 0 -> Y == 31 + // (1 << Y) <= 0 -> Y == 31 + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + + // (1 << Y) >= 0 -> Y != 31 + // (1 << Y) > 0 -> Y != 31 + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); + } + } else if (Cmp.isEquality() && CIsPowerOf2) { + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C->logBase2())); + } + + return nullptr; +} + +/// Fold icmp (shl X, Y), C. +Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, + BinaryOperator *Shl, + const APInt *C) { + const APInt *ShiftVal; + if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal))) + return foldICmpShlConstConst(Cmp, Shl->getOperand(1), *C, *ShiftVal); + + const APInt *ShiftAmt; + if (!match(Shl->getOperand(1), m_APInt(ShiftAmt))) + return foldICmpShlOne(Cmp, Shl, C); + + // Check that the shift amount is in range. If not, don't perform undefined + // shifts. When the shift is visited, it will be simplified. + unsigned TypeBits = C->getBitWidth(); + if (ShiftAmt->uge(TypeBits)) + return nullptr; + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Shl->getOperand(0); + if (Cmp.isEquality()) { + // If the shift is NUW, then it is just shifting out zeros, no need for an + // AND. + Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt)); + if (Shl->hasNoUnsignedWrap()) + return new ICmpInst(Pred, X, LShrC); + + // If the shift is NSW and we compare to 0, then it is just shifting out + // sign bits, no need for an AND either. + if (Shl->hasNoSignedWrap() && *C == 0) + return new ICmpInst(Pred, X, LShrC); + + if (Shl->hasOneUse()) { + // Otherwise, strength reduce the shift into an and. 
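The unsigned cases of foldICmpShlOne reduce a compare of (1 << Y) against a constant to a compare on Y itself; for a non-power-of-two like 30 the boundary is floor(log2(30)) = 4, exactly as the comments state. A brute-force check of the two rewritten predicates (identity check only):

#include <cassert>

int main() {
  for (unsigned Y = 0; Y < 32; ++Y) {
    unsigned Shl = 1u << Y;
    assert((Shl < 30u) == (Y <= 4)); // (1 << Y) u< 30  -->  Y u<= 4
    assert((Shl >= 30u) == (Y > 4)); // (1 << Y) u>= 30 -->  Y u> 4
  }
  return 0;
}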
+ Constant *Mask = ConstantInt::get(Shl->getType(), + APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); + + Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); + return new ICmpInst(Pred, And, LShrC); + } + } + + // If this is a signed comparison to 0 and the shift is sign preserving, + // use the shift LHS operand instead; isSignTest may change 'Pred', so only + // do that if we're sure to not continue on in this function. + if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C)) + return new ICmpInst(Pred, X, Constant::getNullValue(X->getType())); + + // Otherwise, if this is a comparison of the sign bit, simplify to and/test. + bool TrueIfSigned = false; + if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) { + // (X << 31) <s 0 --> (X & 1) != 0 + Constant *Mask = ConstantInt::get( + X->getType(), + APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1)); + Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); + return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, + And, Constant::getNullValue(And->getType())); + } + + // When the shift is nuw and pred is >u or <=u, comparison only really happens + // in the pre-shifted bits. Since InstSimplify canonicalizes <=u into <u, the + // <=u case can be further converted to match <u (see below). + if (Shl->hasNoUnsignedWrap() && + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) { + // Derivation for the ult case: + // (X << S) <=u C is equiv to X <=u (C >> S) for all C + // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u + // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 + assert((Pred != ICmpInst::ICMP_ULT || C->ugt(0)) && + "Encountered `ult 0` that should have been eliminated by " + "InstSimplify."); + APInt ShiftedC = Pred == ICmpInst::ICMP_ULT ? (*C - 1).lshr(*ShiftAmt) + 1 + : C->lshr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), ShiftedC)); + } + + // Transform (icmp pred iM (shl iM %v, N), C) + // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (C>>N)) + // Transform the shl to a trunc if (trunc (C>>N)) has no loss and M-N. + // This enables us to get rid of the shift in favor of a trunc that may be + // free on the target. It has the additional benefit of comparing to a + // smaller constant that may be more target-friendly. + unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); + if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt && + DL.isLegalInteger(TypeBits - Amt)) { + Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); + if (X->getType()->isVectorTy()) + TruncTy = VectorType::get(TruncTy, X->getType()->getVectorNumElements()); + Constant *NewC = + ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt)); + return new ICmpInst(Pred, Builder->CreateTrunc(X, TruncTy), NewC); + } + + return nullptr; +} + +/// Fold icmp ({al}shr X, Y), C. 
+Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, + BinaryOperator *Shr, + const APInt *C) { + // An exact shr only shifts out zero bits, so: + // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 + Value *X = Shr->getOperand(0); + CmpInst::Predicate Pred = Cmp.getPredicate(); + if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && *C == 0) + return new ICmpInst(Pred, X, Cmp.getOperand(1)); + + const APInt *ShiftVal; + if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal))) + return foldICmpShrConstConst(Cmp, Shr->getOperand(1), *C, *ShiftVal); + + const APInt *ShiftAmt; + if (!match(Shr->getOperand(1), m_APInt(ShiftAmt))) + return nullptr; + + // Check that the shift amount is in range. If not, don't perform undefined + // shifts. When the shift is visited it will be simplified. + unsigned TypeBits = C->getBitWidth(); + unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits); + if (ShAmtVal >= TypeBits || ShAmtVal == 0) + return nullptr; + + bool IsAShr = Shr->getOpcode() == Instruction::AShr; + if (!Cmp.isEquality()) { + // If we have an unsigned comparison and an ashr, we can't simplify this. + // Similarly for signed comparisons with lshr. + if (Cmp.isSigned() != IsAShr) + return nullptr; + + // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv + // by a power of 2. Since we already have logic to simplify these, + // transform to div and then simplify the resultant comparison. + if (IsAShr && (!Shr->isExact() || ShAmtVal == TypeBits - 1)) + return nullptr; + + // Revisit the shift (to delete it). + Worklist.Add(Shr); + + Constant *DivCst = ConstantInt::get( + Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); + + Value *Tmp = IsAShr ? Builder->CreateSDiv(X, DivCst, "", Shr->isExact()) + : Builder->CreateUDiv(X, DivCst, "", Shr->isExact()); + + Cmp.setOperand(0, Tmp); + + // If the builder folded the binop, just return it. + BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); + if (!TheDiv) + return &Cmp; + + // Otherwise, fold this div/compare. + assert(TheDiv->getOpcode() == Instruction::SDiv || + TheDiv->getOpcode() == Instruction::UDiv); + + Instruction *Res = foldICmpDivConstant(Cmp, TheDiv, C); + assert(Res && "This div/cst should have folded!"); + return Res; + } + + // Handle equality comparisons of shift-by-constant. + + // If the comparison constant changes with the shift, the comparison cannot + // succeed (bits of the comparison constant cannot match the shifted value). + // This should be known by InstSimplify and already be folded to true/false. + assert(((IsAShr && C->shl(ShAmtVal).ashr(ShAmtVal) == *C) || + (!IsAShr && C->shl(ShAmtVal).lshr(ShAmtVal) == *C)) && + "Expected icmp+shr simplify did not occur."); + + // Check if the bits shifted out are known to be zero. If so, we can compare + // against the unshifted value: + // (X & 4) >> 1 == 2 --> (X & 4) == 4. + Constant *ShiftedCmpRHS = ConstantInt::get(Shr->getType(), *C << ShAmtVal); + if (Shr->hasOneUse()) { + if (Shr->isExact()) + return new ICmpInst(Pred, X, ShiftedCmpRHS); + + // Otherwise strength reduce the shift into an 'and'. + APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); + Constant *Mask = ConstantInt::get(Shr->getType(), Val); + Value *And = Builder->CreateAnd(X, Mask, Shr->getName() + ".mask"); + return new ICmpInst(Pred, And, ShiftedCmpRHS); + } + + return nullptr; +} + +/// Fold icmp (udiv X, Y), C. 
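The equality strength-reduction at the end of foldICmpShrConstant depends on the fact that, when the compared constant survives the shift, (X u>> S) == C is the same as comparing X with its low S bits masked off against C << S. Exhaustively checking that on i8 (illustrative, not the patch's code):

#include <cassert>

int main() {
  for (unsigned S = 1; S < 8; ++S) {
    unsigned HighMask = 0xFFu & ~((1u << S) - 1u); // keeps the bits the shift preserved
    for (unsigned C = 0; C < (1u << (8 - S)); ++C) // constants where C << S loses nothing
      for (unsigned X = 0; X <= 0xFFu; ++X)
        assert(((X >> S) == C) == ((X & HighMask) == (C << S)));
  }
  return 0;
}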
+Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, + BinaryOperator *UDiv, + const APInt *C) { + const APInt *C2; + if (!match(UDiv->getOperand(0), m_APInt(C2))) + return nullptr; + + assert(C2 != 0 && "udiv 0, X should have been simplified already."); + + // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) + Value *Y = UDiv->getOperand(1); + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) { + assert(!C->isMaxValue() && + "icmp ugt X, UINT_MAX should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_ULE, Y, + ConstantInt::get(Y->getType(), C2->udiv(*C + 1))); + } + + // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) + if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) { + assert(C != 0 && "icmp ult X, 0 should have been simplified already."); + return new ICmpInst(ICmpInst::ICMP_UGT, Y, + ConstantInt::get(Y->getType(), C2->udiv(*C))); + } + + return nullptr; +} + +/// Fold icmp ({su}div X, Y), C. +Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, + BinaryOperator *Div, + const APInt *C) { + // Fold: icmp pred ([us]div X, C2), C -> range test + // Fold this div into the comparison, producing a range check. + // Determine, based on the divide type, what the range is being + // checked. If there is an overflow on the low or high side, remember + // it, otherwise compute the range [low, hi) bounding the new value. + // See: InsertRangeTest above for the kinds of replacements possible. + const APInt *C2; + if (!match(Div->getOperand(1), m_APInt(C2))) + return nullptr; // FIXME: If the operand types don't match the type of the divide // then don't attempt this transform. The code below doesn't have the // logic to deal with a signed divide and an unsigned compare (and - // vice versa). This is because (x /s C1) <s C2 produces different - // results than (x /s C1) <u C2 or (x /u C1) <s C2 or even - // (x /u C1) <u C2. Simply casting the operands and result won't + // vice versa). This is because (x /s C2) <s C produces different + // results than (x /s C2) <u C or (x /u C2) <s C or even + // (x /u C2) <u C. Simply casting the operands and result won't // work. :( The if statement below tests that condition and bails // if it finds it. - bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv; - if (!ICI.isEquality() && DivIsSigned != ICI.isSigned()) + bool DivIsSigned = Div->getOpcode() == Instruction::SDiv; + if (!Cmp.isEquality() && DivIsSigned != Cmp.isSigned()) return nullptr; - if (DivRHS->isZero()) - return nullptr; // The ProdOV computation fails on divide by zero. - if (DivIsSigned && DivRHS->isAllOnesValue()) - return nullptr; // The overflow computation also screws up here - if (DivRHS->isOne()) { - // This eliminates some funny cases with INT_MIN. - ICI.setOperand(0, DivI->getOperand(0)); // X/1 == X. - return &ICI; - } - - // Compute Prod = CI * DivRHS. We are essentially solving an equation - // of form X/C1=C2. We solve for X by multiplying C1 (DivRHS) and - // C2 (CI). By solving for X we can turn this into a range check - // instead of computing a divide. + + // The ProdOV computation fails on divide by 0 and divide by -1. Cases with + // INT_MIN will also fail if the divisor is 1. Although folds of all these + // division-by-constant cases should be present, we can not assert that they + // have happened before we reach this icmp instruction. + if (*C2 == 0 || *C2 == 1 || (DivIsSigned && C2->isAllOnesValue())) + return nullptr; + + // TODO: We could do all of the computations below using APInt. 
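[Editor's note] The (udiv C2, Y) rewrites above follow from the usual floor-division fact that C2/Y >= K exactly when K*Y <= C2, i.e. when Y <= C2/K, for unsigned Y >= 1 and K >= 1. A small brute-force confirmation (a sketch, not part of the patch; plain unsigned arithmetic, Y = 0 skipped because udiv by zero is undefined):

  #include <cassert>

  int main() {
    for (unsigned C2 = 0; C2 <= 255; ++C2)
      for (unsigned Y = 1; Y <= 255; ++Y) {
        unsigned Q = C2 / Y;
        for (unsigned C = 0; C <= 255; ++C) {
          if (C != 255)                     // `ugt UINT_MAX` is folded by InstSimplify
            assert((Q > C) == (Y <= C2 / (C + 1)));
          if (C != 0)                       // `ult 0` is folded by InstSimplify
            assert((Q < C) == (Y > C2 / C));
        }
      }
    return 0;
  }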
+ Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1)); + Constant *DivRHS = cast<Constant>(Div->getOperand(1)); + + // Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of + // form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS). + // By solving for X, we can turn this into a range check instead of computing + // a divide. Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS); - // Determine if the product overflows by seeing if the product is - // not equal to the divide. Make sure we do the same kind of divide - // as in the LHS instruction that we're folding. - bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) : - ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; + // Determine if the product overflows by seeing if the product is not equal to + // the divide. Make sure we do the same kind of divide as in the LHS + // instruction that we're folding. + bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) + : ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; - // Get the ICmp opcode - ICmpInst::Predicate Pred = ICI.getPredicate(); + ICmpInst::Predicate Pred = Cmp.getPredicate(); // If the division is known to be exact, then there is no remainder from the // divide, so the covered range size is unit, otherwise it is the divisor. - ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS; + Constant *RangeSize = + Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS; // Figure out the interval that is being checked. For example, a comparison // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). @@ -1245,1134 +2184,1094 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, if (!HiOverflow) { // If this is not an exact divide, then many values in the range collapse // to the same result value. - HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false); + HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); } - } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0. - if (CmpRHSV == 0) { // (X / pos) op 0 + } else if (C2->isStrictlyPositive()) { // Divisor is > 0. + if (*C == 0) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); HiBound = RangeSize; - } else if (CmpRHSV.isStrictlyPositive()) { // (X / pos) op pos + } else if (C->isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) - HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true); + HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); } else { // (X / pos) op neg // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) HiBound = AddOne(Prod); LoOverflow = HiOverflow = ProdOV ? -1 : 0; if (!LoOverflow) { - ConstantInt *DivNeg =cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); - LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; + Constant *DivNeg = ConstantExpr::getNeg(RangeSize); + LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; } } - } else if (DivRHS->isNegative()) { // Divisor is < 0. - if (DivI->isExact()) - RangeSize = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); - if (CmpRHSV == 0) { // (X / neg) op 0 + } else if (C2->isNegative()) { // Divisor is < 0. + if (Div->isExact()) + RangeSize = ConstantExpr::getNeg(RangeSize); + if (*C == 0) { // (X / neg) op 0 // e.g. 
X/-5 op 0 --> [-4, 5) LoBound = AddOne(RangeSize); - HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); + HiBound = ConstantExpr::getNeg(RangeSize); if (HiBound == DivRHS) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN } - } else if (CmpRHSV.isStrictlyPositive()) { // (X / neg) op pos + } else if (C->isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) HiBound = AddOne(Prod); HiOverflow = LoOverflow = ProdOV ? -1 : 0; if (!LoOverflow) - LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; + LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; } else { // (X / neg) op neg LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) LoOverflow = HiOverflow = ProdOV; if (!HiOverflow) - HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true); + HiOverflow = subWithOverflow(HiBound, Prod, RangeSize, true); } // Dividing by a negative swaps the condition. LT <-> GT Pred = ICmpInst::getSwappedPredicate(Pred); } - Value *X = DivI->getOperand(0); + Value *X = Div->getOperand(0); switch (Pred) { - default: llvm_unreachable("Unhandled icmp opcode!"); - case ICmpInst::ICMP_EQ: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, LoBound); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, HiBound); - return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, - DivIsSigned, true)); - case ICmpInst::ICMP_NE: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(ICI, Builder->getTrue()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, LoBound); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, HiBound); - return replaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, - DivIsSigned, false)); - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLT: - if (LoOverflow == +1) // Low bound is greater than input range. - return replaceInstUsesWith(ICI, Builder->getTrue()); - if (LoOverflow == -1) // Low bound is less than input range. - return replaceInstUsesWith(ICI, Builder->getFalse()); - return new ICmpInst(Pred, X, LoBound); - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGT: - if (HiOverflow == +1) // High bound greater than input range. - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (HiOverflow == -1) // High bound less than input range. - return replaceInstUsesWith(ICI, Builder->getTrue()); - if (Pred == ICmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); - return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + default: llvm_unreachable("Unhandled icmp opcode!"); + case ICmpInst::ICMP_EQ: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder->getFalse()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : + ICmpInst::ICMP_UGE, X, LoBound); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, X, HiBound); + return replaceInstUsesWith( + Cmp, insertRangeTest(X, LoBound->getUniqueInteger(), + HiBound->getUniqueInteger(), DivIsSigned, true)); + case ICmpInst::ICMP_NE: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder->getTrue()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? 
ICmpInst::ICMP_SLT : + ICmpInst::ICMP_ULT, X, LoBound); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : + ICmpInst::ICMP_UGE, X, HiBound); + return replaceInstUsesWith(Cmp, + insertRangeTest(X, LoBound->getUniqueInteger(), + HiBound->getUniqueInteger(), + DivIsSigned, false)); + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + if (LoOverflow == +1) // Low bound is greater than input range. + return replaceInstUsesWith(Cmp, Builder->getTrue()); + if (LoOverflow == -1) // Low bound is less than input range. + return replaceInstUsesWith(Cmp, Builder->getFalse()); + return new ICmpInst(Pred, X, LoBound); + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + if (HiOverflow == +1) // High bound greater than input range. + return replaceInstUsesWith(Cmp, Builder->getFalse()); + if (HiOverflow == -1) // High bound less than input range. + return replaceInstUsesWith(Cmp, Builder->getTrue()); + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); + return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); } + + return nullptr; } -/// Handle "icmp(([al]shr X, cst1), cst2)". -Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr, - ConstantInt *ShAmt) { - const APInt &CmpRHSV = cast<ConstantInt>(ICI.getOperand(1))->getValue(); +/// Fold icmp (sub X, Y), C. +Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, + BinaryOperator *Sub, + const APInt *C) { + Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); + ICmpInst::Predicate Pred = Cmp.getPredicate(); - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. - uint32_t TypeBits = CmpRHSV.getBitWidth(); - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - if (ShAmtVal >= TypeBits || ShAmtVal == 0) + // The following transforms are only worth it if the only user of the subtract + // is the icmp. + if (!Sub->hasOneUse()) return nullptr; - if (!ICI.isEquality()) { - // If we have an unsigned comparison and an ashr, we can't simplify this. - // Similarly for signed comparisons with lshr. - if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr)) - return nullptr; - - // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv - // by a power of 2. Since we already have logic to simplify these, - // transform to div and then simplify the resultant comparison. - if (Shr->getOpcode() == Instruction::AShr && - (!Shr->isExact() || ShAmtVal == TypeBits - 1)) - return nullptr; - - // Revisit the shift (to delete it). - Worklist.Add(Shr); - - Constant *DivCst = - ConstantInt::get(Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); + if (Sub->hasNoSignedWrap()) { + // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) + if (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue()) + return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); - Value *Tmp = - Shr->getOpcode() == Instruction::AShr ? - Builder->CreateSDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()) : - Builder->CreateUDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()); + // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) + if (Pred == ICmpInst::ICMP_SGT && *C == 0) + return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); - ICI.setOperand(0, Tmp); + // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) + if (Pred == ICmpInst::ICMP_SLT && *C == 0) + return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); - // If the builder folded the binop, just return it. 
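[Editor's note] The interval reasoning in foldICmpDivConstant above ("X /u 5 == 0" really checks X in [0, 5), and a negative divisor swaps the bounds) can be spot-checked directly. A tiny sketch (an illustration, not part of the patch; ordinary C++ truncating division matches LLVM's udiv/sdiv here):

  #include <cassert>

  int main() {
    for (unsigned Xu = 0; Xu <= 255; ++Xu)                 // X /u 5 == 0  <=>  X u< 5
      assert((Xu / 5 == 0) == (Xu < 5));
    for (int Xs = -128; Xs <= 127; ++Xs) {
      assert((Xs / 5 == 3) == (Xs >= 15 && Xs < 20));      // X /s 5 == 3   <=>  X in [15, 20)
      assert((Xs / -5 == 0) == (Xs >= -4 && Xs < 5));      // X /s -5 == 0  <=>  X in [-4, 5)
    }
    return 0;
  }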
- BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); - if (!TheDiv) - return &ICI; - - // Otherwise, fold this div/compare. - assert(TheDiv->getOpcode() == Instruction::SDiv || - TheDiv->getOpcode() == Instruction::UDiv); - - Instruction *Res = FoldICmpDivCst(ICI, TheDiv, cast<ConstantInt>(DivCst)); - assert(Res && "This div/cst should have folded!"); - return Res; + // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) + if (Pred == ICmpInst::ICMP_SLT && *C == 1) + return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - APInt Comp = CmpRHSV << ShAmtVal; - ConstantInt *ShiftedCmpRHS = Builder->getInt(Comp); - if (Shr->getOpcode() == Instruction::LShr) - Comp = Comp.lshr(ShAmtVal); - else - Comp = Comp.ashr(ShAmtVal); + const APInt *C2; + if (!match(X, m_APInt(C2))) + return nullptr; - if (Comp != CmpRHSV) { // Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = Builder->getInt1(IsICMP_NE); - return replaceInstUsesWith(ICI, Cst); - } + // C2 - Y <u C -> (Y | (C - 1)) == C2 + // iff (C2 & (C - 1)) == C - 1 and C is a power of 2 + if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && + (*C2 & (*C - 1)) == (*C - 1)) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateOr(Y, *C - 1), X); - // Otherwise, check to see if the bits shifted out are known to be zero. - // If so, we can compare against the unshifted value: - // (X & 4) >> 1 == 2 --> (X & 4) == 4. - if (Shr->hasOneUse() && Shr->isExact()) - return new ICmpInst(ICI.getPredicate(), Shr->getOperand(0), ShiftedCmpRHS); + // C2 - Y >u C -> (Y | C) != C2 + // iff C2 & C == C and C + 1 is a power of 2 + if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == *C) + return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateOr(Y, *C), X); - if (Shr->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = Builder->getInt(Val); - - Value *And = Builder->CreateAnd(Shr->getOperand(0), - Mask, Shr->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, ShiftedCmpRHS); - } return nullptr; } -/// Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" -> -/// (icmp eq/ne A, Log2(const2/const1)) -> -/// (icmp eq/ne A, Log2(const2) - Log2(const1)). -Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, - ConstantInt *CI2) { - assert(I.isEquality() && "Cannot fold icmp gt/lt"); - - auto getConstant = [&I, this](bool IsTrue) { - if (I.getPredicate() == I.ICMP_NE) - IsTrue = !IsTrue; - return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); - }; - - auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { - if (I.getPredicate() == I.ICMP_NE) - Pred = CmpInst::getInversePredicate(Pred); - return new ICmpInst(Pred, LHS, RHS); - }; +/// Fold icmp (add X, Y), C. +Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, + BinaryOperator *Add, + const APInt *C) { + Value *Y = Add->getOperand(1); + const APInt *C2; + if (Cmp.isEquality() || !match(Y, m_APInt(C2))) + return nullptr; - const APInt &AP1 = CI1->getValue(); - const APInt &AP2 = CI2->getValue(); + // Fold icmp pred (add X, C2), C. 
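[Editor's note] The (C2 - Y) u< C fold in foldICmpSubConstant above relies on C being a power of two and C2 having the low C - 1 bits already set, so the admissible values of Y differ from C2 only in those low bits. A brute-force check under exactly those side conditions (a sketch, not part of the patch; 8-bit domain assumed):

  #include <cassert>

  int main() {
    const unsigned Mask = 0xff;
    for (unsigned C = 1; C <= 128; C <<= 1)                // C is a power of 2
      for (unsigned C2 = 0; C2 <= Mask; ++C2) {
        if ((C2 & (C - 1)) != C - 1)                       // side condition of the fold
          continue;
        for (unsigned Y = 0; Y <= Mask; ++Y) {
          bool Original = ((C2 - Y) & Mask) < C;           // (C2 - Y) u< C, wrapping i8 sub
          bool Folded = (Y | (C - 1)) == C2;
          assert(Original == Folded);
        }
      }
    return 0;
  }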
+ Value *X = Add->getOperand(0); + Type *Ty = Add->getType(); + auto CR = + ConstantRange::makeExactICmpRegion(Cmp.getPredicate(), *C).subtract(*C2); + const APInt &Upper = CR.getUpper(); + const APInt &Lower = CR.getLower(); + if (Cmp.isSigned()) { + if (Lower.isSignBit()) + return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, Upper)); + if (Upper.isSignBit()) + return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, Lower)); + } else { + if (Lower.isMinValue()) + return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, Upper)); + if (Upper.isMinValue()) + return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower)); + } - // Don't bother doing any work for cases which InstSimplify handles. - if (AP2 == 0) + if (!Add->hasOneUse()) return nullptr; - bool IsAShr = isa<AShrOperator>(Op); - if (IsAShr) { - if (AP2.isAllOnesValue()) - return nullptr; - if (AP2.isNegative() != AP1.isNegative()) - return nullptr; - if (AP2.sgt(AP1)) - return nullptr; - } - if (!AP1) - // 'A' must be large enough to shift out the highest set bit. - return getICmp(I.ICMP_UGT, A, - ConstantInt::get(A->getType(), AP2.logBase2())); + // X+C <u C2 -> (X & -C2) == C + // iff C & (C2-1) == 0 + // C2 is a power of 2 + if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() && + (*C2 & (*C - 1)) == 0) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)), + ConstantExpr::getNeg(cast<Constant>(Y))); + + // X+C >u C2 -> (X & ~C2) != C + // iff C & C2 == 0 + // C2+1 is a power of 2 + if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && + (*C2 & *C) == 0) + return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)), + ConstantExpr::getNeg(cast<Constant>(Y))); - if (AP1 == AP2) - return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + return nullptr; +} - int Shift; - if (IsAShr && AP1.isNegative()) - Shift = AP1.countLeadingOnes() - AP2.countLeadingOnes(); - else - Shift = AP1.countLeadingZeros() - AP2.countLeadingZeros(); +/// Try to fold integer comparisons with a constant operand: icmp Pred X, C +/// where X is some kind of instruction. +Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C))) + return nullptr; - if (Shift > 0) { - if (IsAShr && AP1 == AP2.ashr(Shift)) { - // There are multiple solutions if we are comparing against -1 and the LHS - // of the ashr is not a power of two. 
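[Editor's note] The "X+C <u C2" comment in foldICmpAddConstant above appears to keep the older operand naming (its C/C2 roles are swapped relative to the new code). Stated neutrally, the claim is that (X + A) u< P is equivalent to (X & -P) == -A when P is a power of two and A is a multiple of P, because -A is then also a multiple of P and the test picks out exactly the values whose high bits match those of -A. A brute-force confirmation (a sketch, not part of the patch; 8-bit domain, A and P are illustrative names):

  #include <cassert>

  int main() {
    const unsigned Mask = 0xff;
    for (unsigned P = 1; P <= 128; P <<= 1)                // power-of-2 compare constant
      for (unsigned A = 0; A <= Mask; ++A) {
        if ((A & (P - 1)) != 0)                            // add constant must be a multiple of P
          continue;
        unsigned NegP = (0u - P) & Mask;                   // 8-bit -P, i.e. ~(P - 1)
        unsigned NegA = (0u - A) & Mask;                   // 8-bit -A
        for (unsigned X = 0; X <= Mask; ++X) {
          bool Original = ((X + A) & Mask) < P;            // (X + A) u< P, wrapping i8 add
          bool Folded = (X & NegP) == NegA;
          assert(Original == Folded);
        }
      }
    return 0;
  }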
- if (AP1.isAllOnesValue() && !AP2.isPowerOf2()) - return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift)); - return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); - } else if (AP1 == AP2.lshr(Shift)) { - return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); + BinaryOperator *BO; + if (match(Cmp.getOperand(0), m_BinOp(BO))) { + switch (BO->getOpcode()) { + case Instruction::Xor: + if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) + return I; + break; + case Instruction::And: + if (Instruction *I = foldICmpAndConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Or: + if (Instruction *I = foldICmpOrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Mul: + if (Instruction *I = foldICmpMulConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Shl: + if (Instruction *I = foldICmpShlConstant(Cmp, BO, C)) + return I; + break; + case Instruction::LShr: + case Instruction::AShr: + if (Instruction *I = foldICmpShrConstant(Cmp, BO, C)) + return I; + break; + case Instruction::UDiv: + if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) + return I; + LLVM_FALLTHROUGH; + case Instruction::SDiv: + if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Sub: + if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) + return I; + break; + case Instruction::Add: + if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) + return I; + break; + default: + break; } + // TODO: These folds could be refactored to be part of the above calls. + if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, C)) + return I; } - // Shifting const2 will never be equal to const1. - return getConstant(false); -} -/// Handle "(icmp eq/ne (shl const2, A), const1)" -> -/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)). -Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, - ConstantInt *CI2) { - assert(I.isEquality() && "Cannot fold icmp gt/lt"); + Instruction *LHSI; + if (match(Cmp.getOperand(0), m_Instruction(LHSI)) && + LHSI->getOpcode() == Instruction::Trunc) + if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) + return I; - auto getConstant = [&I, this](bool IsTrue) { - if (I.getPredicate() == I.ICMP_NE) - IsTrue = !IsTrue; - return replaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue)); - }; + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C)) + return I; - auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) { - if (I.getPredicate() == I.ICMP_NE) - Pred = CmpInst::getInversePredicate(Pred); - return new ICmpInst(Pred, LHS, RHS); - }; - - const APInt &AP1 = CI1->getValue(); - const APInt &AP2 = CI2->getValue(); + return nullptr; +} - // Don't bother doing any work for cases which InstSimplify handles. - if (AP2 == 0) +/// Fold an icmp equality instruction with binary operator LHS and constant RHS: +/// icmp eq/ne BO, C. +Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt *C) { + // TODO: Some of these folds could work with arbitrary constants, but this + // function is limited to scalar and vector splat constants. 
+ if (!Cmp.isEquality()) return nullptr; - unsigned AP2TrailingZeros = AP2.countTrailingZeros(); - - if (!AP1 && AP2TrailingZeros != 0) - return getICmp(I.ICMP_UGE, A, - ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros)); + ICmpInst::Predicate Pred = Cmp.getPredicate(); + bool isICMP_NE = Pred == ICmpInst::ICMP_NE; + Constant *RHS = cast<Constant>(Cmp.getOperand(1)); + Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); + + switch (BO->getOpcode()) { + case Instruction::SRem: + // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. + if (*C == 0 && BO->hasOneUse()) { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { + Value *NewRem = Builder->CreateURem(BOp0, BOp1, BO->getName()); + return new ICmpInst(Pred, NewRem, + Constant::getNullValue(BO->getType())); + } + } + break; + case Instruction::Add: { + // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. + const APInt *BOC; + if (match(BOp1, m_APInt(BOC))) { + if (BO->hasOneUse()) { + Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1)); + return new ICmpInst(Pred, BOp0, SubC); + } + } else if (*C == 0) { + // Replace ((add A, B) != 0) with (A != -B) if A or B is + // efficiently invertible, or if the add has just this one use. + if (Value *NegVal = dyn_castNegVal(BOp1)) + return new ICmpInst(Pred, BOp0, NegVal); + if (Value *NegVal = dyn_castNegVal(BOp0)) + return new ICmpInst(Pred, NegVal, BOp1); + if (BO->hasOneUse()) { + Value *Neg = Builder->CreateNeg(BOp1); + Neg->takeName(BO); + return new ICmpInst(Pred, BOp0, Neg); + } + } + break; + } + case Instruction::Xor: + if (BO->hasOneUse()) { + if (Constant *BOC = dyn_cast<Constant>(BOp1)) { + // For the xor case, we can xor two constants together, eliminating + // the explicit xor. + return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); + } else if (*C == 0) { + // Replace ((xor A, B) != 0) with (A != B) + return new ICmpInst(Pred, BOp0, BOp1); + } + } + break; + case Instruction::Sub: + if (BO->hasOneUse()) { + const APInt *BOC; + if (match(BOp0, m_APInt(BOC))) { + // Replace ((sub BOC, B) != C) with (B != BOC-C). + Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS); + return new ICmpInst(Pred, BOp1, SubC); + } else if (*C == 0) { + // Replace ((sub A, B) != 0) with (A != B). + return new ICmpInst(Pred, BOp0, BOp1); + } + } + break; + case Instruction::Or: { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && BO->hasOneUse() && RHS->isAllOnesValue()) { + // Comparing if all bits outside of a constant mask are set? + // Replace (X | C) == -1 with (X & ~C) == ~C. + // This removes the -1 constant. + Constant *NotBOC = ConstantExpr::getNot(cast<Constant>(BOp1)); + Value *And = Builder->CreateAnd(BOp0, NotBOC); + return new ICmpInst(Pred, And, NotBOC); + } + break; + } + case Instruction::And: { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC))) { + // If we have ((X & C) == C), turn it into ((X & C) != 0). + if (C == BOC && C->isPowerOf2()) + return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + BO, Constant::getNullValue(RHS->getType())); + + // Don't perform the following transforms if the AND has multiple uses + if (!BO->hasOneUse()) + break; - if (AP1 == AP2) - return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType())); + // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 + if (BOC->isSignBit()) { + Constant *Zero = Constant::getNullValue(BOp0->getType()); + auto NewPred = isICMP_NE ? 
ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; + return new ICmpInst(NewPred, BOp0, Zero); + } - // Get the distance between the lowest bits that are set. - int Shift = AP1.countTrailingZeros() - AP2TrailingZeros; + // ((X & ~7) == 0) --> X < 8 + if (*C == 0 && (~(*BOC) + 1).isPowerOf2()) { + Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1)); + auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; + return new ICmpInst(NewPred, BOp0, NegBOC); + } + } + break; + } + case Instruction::Mul: + if (*C == 0 && BO->hasNoSignedWrap()) { + const APInt *BOC; + if (match(BOp1, m_APInt(BOC)) && *BOC != 0) { + // The trivial case (mul X, 0) is handled by InstSimplify. + // General case : (mul X, C) != 0 iff X != 0 + // (mul X, C) == 0 iff X == 0 + return new ICmpInst(Pred, BOp0, Constant::getNullValue(RHS->getType())); + } + } + break; + case Instruction::UDiv: + if (*C == 0) { + // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) + auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + return new ICmpInst(NewPred, BOp1, BOp0); + } + break; + default: + break; + } + return nullptr; +} - if (Shift > 0 && AP2.shl(Shift) == AP1) - return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); +/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. +Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, + const APInt *C) { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)); + if (!II || !Cmp.isEquality()) + return nullptr; - // Shifting const2 will never be equal to const1. - return getConstant(false); + // Handle icmp {eq|ne} <intrinsic>, intcst. + switch (II->getIntrinsicID()) { + case Intrinsic::bswap: + Worklist.Add(II); + Cmp.setOperand(0, II->getArgOperand(0)); + Cmp.setOperand(1, Builder->getInt(C->byteSwap())); + return &Cmp; + case Intrinsic::ctlz: + case Intrinsic::cttz: + // ctz(A) == bitwidth(A) -> A == 0 and likewise for != + if (*C == C->getBitWidth()) { + Worklist.Add(II); + Cmp.setOperand(0, II->getArgOperand(0)); + Cmp.setOperand(1, ConstantInt::getNullValue(II->getType())); + return &Cmp; + } + break; + case Intrinsic::ctpop: { + // popcount(A) == 0 -> A == 0 and likewise for != + // popcount(A) == bitwidth(A) -> A == -1 and likewise for != + bool IsZero = *C == 0; + if (IsZero || *C == C->getBitWidth()) { + Worklist.Add(II); + Cmp.setOperand(0, II->getArgOperand(0)); + auto *NewOp = IsZero ? Constant::getNullValue(II->getType()) + : Constant::getAllOnesValue(II->getType()); + Cmp.setOperand(1, NewOp); + return &Cmp; + } + break; + } + default: + break; + } + return nullptr; } -/// Handle "icmp (instr, intcst)". -Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, - Instruction *LHSI, - ConstantInt *RHS) { - const APInt &RHSV = RHS->getValue(); +/// Handle icmp with constant (but not simple integer constant) RHS. 
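[Editor's note] One of the equality folds above, (icmp eq (udiv A, B), 0) -> (icmp ugt B, A), is just the statement that an unsigned quotient is zero exactly when the divisor exceeds the dividend. A short confirmation (a sketch, not part of the patch; B = 0 skipped because udiv by zero is undefined):

  #include <cassert>

  int main() {
    for (unsigned A = 0; A <= 255; ++A)
      for (unsigned B = 1; B <= 255; ++B) {
        assert(((A / B) == 0) == (B > A));    // the eq form
        assert(((A / B) != 0) == (B <= A));   // the ne form
      }
    return 0;
  }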
+Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Constant *RHSC = dyn_cast<Constant>(Op1); + Instruction *LHSI = dyn_cast<Instruction>(Op0); + if (!RHSC || !LHSI) + return nullptr; switch (LHSI->getOpcode()) { - case Instruction::Trunc: - if (RHS->isOne() && RHSV.getBitWidth() > 1) { - // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 - Value *V = nullptr; - if (ICI.getPredicate() == ICmpInst::ICMP_SLT && - match(LHSI->getOperand(0), m_Signum(m_Value(V)))) - return new ICmpInst(ICmpInst::ICMP_SLT, V, - ConstantInt::get(V->getType(), 1)); + case Instruction::GetElementPtr: + // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null + if (RHSC->isNullValue() && + cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices()) + return new ICmpInst( + I.getPredicate(), LHSI->getOperand(0), + Constant::getNullValue(LHSI->getOperand(0)->getType())); + break; + case Instruction::PHI: + // Only fold icmp into the PHI if the phi and icmp are in the same + // block. If in the same block, we're encouraging jump threading. If + // not, we are just pessimizing the code by making an i1 phi. + if (LHSI->getParent() == I.getParent()) + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + break; + case Instruction::Select: { + // If either operand of the select is a constant, we can fold the + // comparison into the select arms, which will cause one to be + // constant folded and the select turned into a bitwise or. + Value *Op1 = nullptr, *Op2 = nullptr; + ConstantInt *CI = nullptr; + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { + Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + CI = dyn_cast<ConstantInt>(Op1); } - if (ICI.isEquality() && LHSI->hasOneUse()) { - // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all - // of the high bits truncated out of x are known. - unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(), - SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits(); - APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0); - computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI); - - // If all the high bits are known, we can do this xform. - if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) { - // Pull in the high bits from known-ones set. - APInt NewRHS = RHS->getValue().zext(SrcBits); - NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits-DstBits); - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - Builder->getInt(NewRHS)); - } + if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) { + Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + CI = dyn_cast<ConstantInt>(Op2); + } + + // We only want to perform this transformation if it will not lead to + // additional code. This is true if either both sides of the select + // fold to a constant (in which case the icmp is replaced with a select + // which will usually simplify) or this is the only user of the + // select (in which case we are trading a select+icmp for a simpler + // select+icmp) or all uses of the select can be replaced based on + // dominance information ("Global cases"). + bool Transform = false; + if (Op1 && Op2) + Transform = true; + else if (Op1 || Op2) { + // Local case + if (LHSI->hasOneUse()) + Transform = true; + // Global cases + else if (CI && !CI->isZero()) + // When Op1 is constant try replacing select with second operand. + // Otherwise Op2 is constant and try replacing select with first + // operand. 
+ Transform = + replacedSelectWithOperand(cast<SelectInst>(LHSI), &I, Op1 ? 2 : 1); + } + if (Transform) { + if (!Op1) + Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1), RHSC, + I.getName()); + if (!Op2) + Op2 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(2), RHSC, + I.getName()); + return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); } break; + } + case Instruction::IntToPtr: + // icmp pred inttoptr(X), null -> icmp pred X, 0 + if (RHSC->isNullValue() && + DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType()) + return new ICmpInst( + I.getPredicate(), LHSI->getOperand(0), + Constant::getNullValue(LHSI->getOperand(0)->getType())); + break; - case Instruction::Xor: // (icmp pred (xor X, XorCst), CI) - if (ConstantInt *XorCst = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { - // If this is a comparison that tests the signbit (X < 0) or (x > -1), - // fold the xor. - if ((ICI.getPredicate() == ICmpInst::ICMP_SLT && RHSV == 0) || - (ICI.getPredicate() == ICmpInst::ICMP_SGT && RHSV.isAllOnesValue())) { - Value *CompareVal = LHSI->getOperand(0); - - // If the sign bit of the XorCst is not set, there is no change to - // the operation, just stop using the Xor. - if (!XorCst->isNegative()) { - ICI.setOperand(0, CompareVal); - Worklist.Add(LHSI); - return &ICI; - } + case Instruction::Load: + // Try to optimize things like "A[i] > 4" to index computations. + if (GetElementPtrInst *GEP = + dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) { + if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (GV->isConstant() && GV->hasDefinitiveInitializer() && + !cast<LoadInst>(LHSI)->isVolatile()) + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) + return Res; + } + break; + } - // Was the old condition true if the operand is positive? - bool isTrueIfPositive = ICI.getPredicate() == ICmpInst::ICMP_SGT; + return nullptr; +} - // If so, the new one isn't. - isTrueIfPositive ^= true; +/// Try to fold icmp (binop), X or icmp X, (binop). +Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (isTrueIfPositive) - return new ICmpInst(ICmpInst::ICMP_SGT, CompareVal, - SubOne(RHS)); - else - return new ICmpInst(ICmpInst::ICMP_SLT, CompareVal, - AddOne(RHS)); - } + // Special logic for binary operators. + BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0); + BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1); + if (!BO0 && !BO1) + return nullptr; - if (LHSI->hasOneUse()) { - // (icmp u/s (xor A SignBit), C) -> (icmp s/u A, (xor C SignBit)) - if (!ICI.isEquality() && XorCst->getValue().isSignBit()) { - const APInt &SignBit = XorCst->getValue(); - ICmpInst::Predicate Pred = ICI.isSigned() - ? 
ICI.getUnsignedPredicate() - : ICI.getSignedPredicate(); - return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(RHSV ^ SignBit)); - } + CmpInst::Predicate Pred = I.getPredicate(); + bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; + if (BO0 && isa<OverflowingBinaryOperator>(BO0)) + NoOp0WrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap()); + if (BO1 && isa<OverflowingBinaryOperator>(BO1)) + NoOp1WrapProblem = + ICmpInst::isEquality(Pred) || + (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) || + (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap()); - // (icmp u/s (xor A ~SignBit), C) -> (icmp s/u (xor C ~SignBit), A) - if (!ICI.isEquality() && XorCst->isMaxValue(true)) { - const APInt &NotSignBit = XorCst->getValue(); - ICmpInst::Predicate Pred = ICI.isSigned() - ? ICI.getUnsignedPredicate() - : ICI.getSignedPredicate(); - Pred = ICI.getSwappedPredicate(Pred); - return new ICmpInst(Pred, LHSI->getOperand(0), - Builder->getInt(RHSV ^ NotSignBit)); - } - } + // Analyze the case when either Op0 or Op1 is an add instruction. + // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). + Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; + if (BO0 && BO0->getOpcode() == Instruction::Add) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Add) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } - // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) - // iff -C is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && - XorCst->getValue() == ~RHSV && (RHSV + 1).isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), XorCst); - - // (icmp ult (xor X, C), -C) -> (icmp uge X, C) - // iff -C is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && - XorCst->getValue() == -RHSV && RHSV.isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), XorCst); + // icmp (X+cst) < 0 --> X < -cst + if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) + if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) + if (!RHSC->isMinValue(/*isSigned=*/true)) + return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); + + // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + if ((A == Op1 || B == Op1) && NoOp0WrapProblem) + return new ICmpInst(Pred, A == Op1 ? B : A, + Constant::getNullValue(Op1->getType())); + + // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + if ((C == Op0 || D == Op0) && NoOp1WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), + C == Op0 ? D : C); + + // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. + if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem && + NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) { + // Determine Y and Z in the form icmp (X+Y), (X+Z). 
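[Editor's note] The "icmp (X+Y), X -> icmp Y, 0" style rewrites above are only safe when the add cannot wrap in the sense that matters for the predicate; for a signed predicate that is the nsw case. A brute-force check of the signed variants over i8, considering only non-overflowing additions (a sketch, not part of the patch):

  #include <cassert>

  int main() {
    for (int X = -128; X <= 127; ++X)
      for (int Y = -128; Y <= 127; ++Y) {
        int Sum = X + Y;
        if (Sum < -128 || Sum > 127)      // skip adds that would not be nsw on i8
          continue;
        assert((Sum < X) == (Y < 0));     // icmp slt (add nsw X, Y), X  ->  icmp slt Y, 0
        assert((Sum > X) == (Y > 0));     // icmp sgt (add nsw X, Y), X  ->  icmp sgt Y, 0
        assert((Sum == X) == (Y == 0));   // the equality form needs no wrap condition at all
      }
    return 0;
  }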
+ Value *Y, *Z; + if (A == C) { + // C + B == C + D -> B == D + Y = B; + Z = D; + } else if (A == D) { + // D + B == C + D -> B == C + Y = B; + Z = C; + } else if (B == C) { + // A + C == C + D -> A == D + Y = A; + Z = D; + } else { + assert(B == D); + // A + D == C + D -> A == C + Y = A; + Z = C; } - break; - case Instruction::And: // (icmp pred (and X, AndCst), RHS) - if (LHSI->hasOneUse() && isa<ConstantInt>(LHSI->getOperand(1)) && - LHSI->getOperand(0)->hasOneUse()) { - ConstantInt *AndCst = cast<ConstantInt>(LHSI->getOperand(1)); - - // If the LHS is an AND of a truncating cast, we can widen the - // and/compare to be the input width without changing the value - // produced, eliminating a cast. - if (TruncInst *Cast = dyn_cast<TruncInst>(LHSI->getOperand(0))) { - // We can do this transformation if either the AND constant does not - // have its sign bit set or if it is an equality comparison. - // Extending a relational comparison when we're checking the sign - // bit would not work. - if (ICI.isEquality() || - (!AndCst->isNegative() && RHSV.isNonNegative())) { - Value *NewAnd = - Builder->CreateAnd(Cast->getOperand(0), - ConstantExpr::getZExt(AndCst, Cast->getSrcTy())); - NewAnd->takeName(LHSI); - return new ICmpInst(ICI.getPredicate(), NewAnd, - ConstantExpr::getZExt(RHS, Cast->getSrcTy())); - } - } - - // If the LHS is an AND of a zext, and we have an equality compare, we can - // shrink the and/compare to the smaller type, eliminating the cast. - if (ZExtInst *Cast = dyn_cast<ZExtInst>(LHSI->getOperand(0))) { - IntegerType *Ty = cast<IntegerType>(Cast->getSrcTy()); - // Make sure we don't compare the upper bits, SimplifyDemandedBits - // should fold the icmp to true/false in that case. - if (ICI.isEquality() && RHSV.getActiveBits() <= Ty->getBitWidth()) { - Value *NewAnd = - Builder->CreateAnd(Cast->getOperand(0), - ConstantExpr::getTrunc(AndCst, Ty)); - NewAnd->takeName(LHSI); - return new ICmpInst(ICI.getPredicate(), NewAnd, - ConstantExpr::getTrunc(RHS, Ty)); - } - } - - // If this is: (X >> C1) & C2 != C3 (where any shift and any compare - // could exist), turn it into (X & (C2 << C1)) != (C3 << C1). This - // happens a LOT in code produced by the C front-end, for bitfield - // access. - BinaryOperator *Shift = dyn_cast<BinaryOperator>(LHSI->getOperand(0)); - if (Shift && !Shift->isShift()) - Shift = nullptr; - - ConstantInt *ShAmt; - ShAmt = Shift ? dyn_cast<ConstantInt>(Shift->getOperand(1)) : nullptr; - - // This seemingly simple opportunity to fold away a shift turns out to - // be rather complicated. See PR17827 - // ( http://llvm.org/bugs/show_bug.cgi?id=17827 ) for details. - if (ShAmt) { - bool CanFold = false; - unsigned ShiftOpcode = Shift->getOpcode(); - if (ShiftOpcode == Instruction::AShr) { - // There may be some constraints that make this possible, - // but nothing simple has been discovered yet. - CanFold = false; - } else if (ShiftOpcode == Instruction::Shl) { - // For a left shift, we can fold if the comparison is not signed. - // We can also fold a signed comparison if the mask value and - // comparison value are not negative. These constraints may not be - // obvious, but we can prove that they are correct using an SMT - // solver. - if (!ICI.isSigned() || (!AndCst->isNegative() && !RHS->isNegative())) - CanFold = true; - } else if (ShiftOpcode == Instruction::LShr) { - // For a logical right shift, we can fold if the comparison is not - // signed. 
We can also fold a signed comparison if the shifted mask - // value and the shifted comparison value are not negative. - // These constraints may not be obvious, but we can prove that they - // are correct using an SMT solver. - if (!ICI.isSigned()) - CanFold = true; - else { - ConstantInt *ShiftedAndCst = - cast<ConstantInt>(ConstantExpr::getShl(AndCst, ShAmt)); - ConstantInt *ShiftedRHSCst = - cast<ConstantInt>(ConstantExpr::getShl(RHS, ShAmt)); - - if (!ShiftedAndCst->isNegative() && !ShiftedRHSCst->isNegative()) - CanFold = true; - } - } + return new ICmpInst(Pred, Y, Z); + } - if (CanFold) { - Constant *NewCst; - if (ShiftOpcode == Instruction::Shl) - NewCst = ConstantExpr::getLShr(RHS, ShAmt); - else - NewCst = ConstantExpr::getShl(RHS, ShAmt); - - // Check to see if we are shifting out any of the bits being - // compared. - if (ConstantExpr::get(ShiftOpcode, NewCst, ShAmt) != RHS) { - // If we shifted bits out, the fold is not going to work out. - // As a special case, check to see if this means that the - // result is always true or false now. - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - return replaceInstUsesWith(ICI, Builder->getFalse()); - if (ICI.getPredicate() == ICmpInst::ICMP_NE) - return replaceInstUsesWith(ICI, Builder->getTrue()); + // icmp slt (X + -1), Y -> icmp sle X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT && + match(B, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SLE, A, Op1); + + // icmp sge (X + -1), Y -> icmp sgt X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE && + match(B, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SGT, A, Op1); + + // icmp sle (X + 1), Y -> icmp slt X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_SLT, A, Op1); + + // icmp sgt (X + 1), Y -> icmp sge X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_SGE, A, Op1); + + // icmp sgt X, (Y + -1) -> icmp sge X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT && + match(D, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SGE, Op0, C); + + // icmp sle X, (Y + -1) -> icmp slt X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE && + match(D, m_AllOnes())) + return new ICmpInst(CmpInst::ICMP_SLT, Op0, C); + + // icmp sge X, (Y + 1) -> icmp sgt X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_SGT, Op0, C); + + // icmp slt X, (Y + 1) -> icmp sle X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); + + // if C1 has greater magnitude than C2: + // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y + // s.t. C3 = C1 - C2 + // + // if C2 has greater magnitude than C1: + // icmp (X + C1), (Y + C2) -> icmp X, (Y + C3) + // s.t. 
C3 = C2 - C1 + if (A && C && NoOp0WrapProblem && NoOp1WrapProblem && + (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) + if (ConstantInt *C1 = dyn_cast<ConstantInt>(B)) + if (ConstantInt *C2 = dyn_cast<ConstantInt>(D)) { + const APInt &AP1 = C1->getValue(); + const APInt &AP2 = C2->getValue(); + if (AP1.isNegative() == AP2.isNegative()) { + APInt AP1Abs = C1->getValue().abs(); + APInt AP2Abs = C2->getValue().abs(); + if (AP1Abs.uge(AP2Abs)) { + ConstantInt *C3 = Builder->getInt(AP1 - AP2); + Value *NewAdd = Builder->CreateNSWAdd(A, C3); + return new ICmpInst(Pred, NewAdd, C); } else { - ICI.setOperand(1, NewCst); - Constant *NewAndCst; - if (ShiftOpcode == Instruction::Shl) - NewAndCst = ConstantExpr::getLShr(AndCst, ShAmt); - else - NewAndCst = ConstantExpr::getShl(AndCst, ShAmt); - LHSI->setOperand(1, NewAndCst); - LHSI->setOperand(0, Shift->getOperand(0)); - Worklist.Add(Shift); // Shift is dead. - return &ICI; + ConstantInt *C3 = Builder->getInt(AP2 - AP1); + Value *NewAdd = Builder->CreateNSWAdd(C, C3); + return new ICmpInst(Pred, A, NewAdd); } } } - // Turn ((X >> Y) & C) == 0 into (X & (C << Y)) == 0. The later is - // preferable because it allows the C<<Y expression to be hoisted out - // of a loop if Y is invariant and X is not. - if (Shift && Shift->hasOneUse() && RHSV == 0 && - ICI.isEquality() && !Shift->isArithmeticShift() && - !isa<Constant>(Shift->getOperand(0))) { - // Compute C << Y. - Value *NS; - if (Shift->getOpcode() == Instruction::LShr) { - NS = Builder->CreateShl(AndCst, Shift->getOperand(1)); - } else { - // Insert a logical shift. - NS = Builder->CreateLShr(AndCst, Shift->getOperand(1)); - } + // Analyze the case when either Op0 or Op1 is a sub instruction. + // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). + A = nullptr; + B = nullptr; + C = nullptr; + D = nullptr; + if (BO0 && BO0->getOpcode() == Instruction::Sub) { + A = BO0->getOperand(0); + B = BO0->getOperand(1); + } + if (BO1 && BO1->getOpcode() == Instruction::Sub) { + C = BO1->getOperand(0); + D = BO1->getOperand(1); + } - // Compute X & (C << Y). - Value *NewAnd = - Builder->CreateAnd(Shift->getOperand(0), NS, LHSI->getName()); + // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. + if (A == Op1 && NoOp0WrapProblem) + return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); + + // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. + if (C == Op0 && NoOp1WrapProblem) + return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); + + // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. + if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. + BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, A, C); + + // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. + if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && + // Try not to increase register pressure. 
+ BO0->hasOneUse() && BO1->hasOneUse()) + return new ICmpInst(Pred, D, B); + + // icmp (0-X) < cst --> x > -cst + if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) { + Value *X; + if (match(BO0, m_Neg(m_Value(X)))) + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) + if (!RHSC->isMinValue(/*isSigned=*/true)) + return new ICmpInst(I.getSwappedPredicate(), X, + ConstantExpr::getNeg(RHSC)); + } - ICI.setOperand(0, NewAnd); - return &ICI; - } + BinaryOperator *SRem = nullptr; + // icmp (srem X, Y), Y + if (BO0 && BO0->getOpcode() == Instruction::SRem && Op1 == BO0->getOperand(1)) + SRem = BO0; + // icmp Y, (srem X, Y) + else if (BO1 && BO1->getOpcode() == Instruction::SRem && + Op0 == BO1->getOperand(1)) + SRem = BO1; + if (SRem) { + // We don't check hasOneUse to avoid increasing register pressure because + // the value we use is the same value this instruction was already using. + switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { + default: + break; + case ICmpInst::ICMP_EQ: + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + case ICmpInst::ICMP_NE: + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: + return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), + Constant::getAllOnesValue(SRem->getType())); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: + return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1), + Constant::getNullValue(SRem->getType())); + } + } - // (icmp pred (and (or (lshr X, Y), X), 1), 0) --> - // (icmp pred (and X, (or (shl 1, Y), 1), 0)) - // - // iff pred isn't signed - { - Value *X, *Y, *LShr; - if (!ICI.isSigned() && RHSV == 0) { - if (match(LHSI->getOperand(1), m_One())) { - Constant *One = cast<Constant>(LHSI->getOperand(1)); - Value *Or = LHSI->getOperand(0); - if (match(Or, m_Or(m_Value(LShr), m_Value(X))) && - match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) { - unsigned UsesRemoved = 0; - if (LHSI->hasOneUse()) - ++UsesRemoved; - if (Or->hasOneUse()) - ++UsesRemoved; - if (LShr->hasOneUse()) - ++UsesRemoved; - Value *NewOr = nullptr; - // Compute X & ((1 << Y) | 1) - if (auto *C = dyn_cast<Constant>(Y)) { - if (UsesRemoved >= 1) - NewOr = - ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One); - } else { - if (UsesRemoved >= 3) - NewOr = Builder->CreateOr(Builder->CreateShl(One, Y, - LShr->getName(), - /*HasNUW=*/true), - One, Or->getName()); - } - if (NewOr) { - Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName()); - ICI.setOperand(0, NewAnd); - return &ICI; - } - } - } + if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && BO0->hasOneUse() && + BO1->hasOneUse() && BO0->getOperand(1) == BO1->getOperand(1)) { + switch (BO0->getOpcode()) { + default: + break; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Xor: + if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { + if (CI->getValue().isSignBit()) { + ICmpInst::Predicate Pred = + I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate(); + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); } - } - // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any - // bit set in (X & AndCst) will produce a result greater than RHSV. 
- if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { - unsigned NTZ = AndCst->getValue().countTrailingZeros(); - if ((NTZ < AndCst->getBitWidth()) && - APInt::getOneBitSet(AndCst->getBitWidth(), NTZ).ugt(RHSV)) - return new ICmpInst(ICmpInst::ICMP_NE, LHSI, - Constant::getNullValue(RHS->getType())); + if (BO0->getOpcode() == Instruction::Xor && CI->isMaxValue(true)) { + ICmpInst::Predicate Pred = + I.isSigned() ? I.getUnsignedPredicate() : I.getSignedPredicate(); + Pred = I.getSwappedPredicate(Pred); + return new ICmpInst(Pred, BO0->getOperand(0), BO1->getOperand(0)); + } } - } - - // Try to optimize things like "A[i]&42 == 0" to index computations. - if (LoadInst *LI = dyn_cast<LoadInst>(LHSI->getOperand(0))) { - if (GetElementPtrInst *GEP = - dyn_cast<GetElementPtrInst>(LI->getOperand(0))) - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !LI->isVolatile() && isa<ConstantInt>(LHSI->getOperand(1))) { - ConstantInt *C = cast<ConstantInt>(LHSI->getOperand(1)); - if (Instruction *Res = FoldCmpLoadFromIndexedGlobal(GEP, GV,ICI, C)) - return Res; - } - } + break; + case Instruction::Mul: + if (!I.isEquality()) + break; - // X & -C == -C -> X > u ~C - // X & -C != -C -> X <= u ~C - // iff C is a power of 2 - if (ICI.isEquality() && RHS == LHSI->getOperand(1) && (-RHSV).isPowerOf2()) - return new ICmpInst( - ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT - : ICmpInst::ICMP_ULE, - LHSI->getOperand(0), SubOne(RHS)); - - // (icmp eq (and %A, C), 0) -> (icmp sgt (trunc %A), -1) - // iff C is a power of 2 - if (ICI.isEquality() && LHSI->hasOneUse() && match(RHS, m_Zero())) { - if (auto *CI = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { - const APInt &AI = CI->getValue(); - int32_t ExactLogBase2 = AI.exactLogBase2(); - if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { - Type *NTy = IntegerType::get(ICI.getContext(), ExactLogBase2 + 1); - Value *Trunc = Builder->CreateTrunc(LHSI->getOperand(0), NTy); - return new ICmpInst(ICI.getPredicate() == ICmpInst::ICMP_EQ - ? ICmpInst::ICMP_SGE - : ICmpInst::ICMP_SLT, - Trunc, Constant::getNullValue(NTy)); + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { + // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask + // Mask = -1 >> count-trailing-zeros(Cst). 
+ if (!CI->isZero() && !CI->isOne()) { + const APInt &AP = CI->getValue(); + ConstantInt *Mask = ConstantInt::get( + I.getContext(), + APInt::getLowBitsSet(AP.getBitWidth(), + AP.getBitWidth() - AP.countTrailingZeros())); + Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask); + Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask); + return new ICmpInst(I.getPredicate(), And1, And2); } } + break; + case Instruction::UDiv: + case Instruction::LShr: + if (I.isSigned()) + break; + LLVM_FALLTHROUGH; + case Instruction::SDiv: + case Instruction::AShr: + if (!BO0->isExact() || !BO1->isExact()) + break; + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); + case Instruction::Shl: { + bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap(); + bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap(); + if (!NUW && !NSW) + break; + if (!NSW && I.isSigned()) + break; + return new ICmpInst(I.getPredicate(), BO0->getOperand(0), + BO1->getOperand(0)); } - break; - - case Instruction::Or: { - if (RHS->isOne()) { - // icmp slt signum(V) 1 --> icmp slt V, 1 - Value *V = nullptr; - if (ICI.getPredicate() == ICmpInst::ICMP_SLT && - match(LHSI, m_Signum(m_Value(V)))) - return new ICmpInst(ICmpInst::ICMP_SLT, V, - ConstantInt::get(V->getType(), 1)); } + } - if (!ICI.isEquality() || !RHS->isNullValue() || !LHSI->hasOneUse()) - break; - Value *P, *Q; - if (match(LHSI, m_Or(m_PtrToInt(m_Value(P)), m_PtrToInt(m_Value(Q))))) { - // Simplify icmp eq (or (ptrtoint P), (ptrtoint Q)), 0 - // -> and (icmp eq P, null), (icmp eq Q, null). - Value *ICIP = Builder->CreateICmp(ICI.getPredicate(), P, - Constant::getNullValue(P->getType())); - Value *ICIQ = Builder->CreateICmp(ICI.getPredicate(), Q, - Constant::getNullValue(Q->getType())); - Instruction *Op; - if (ICI.getPredicate() == ICmpInst::ICMP_EQ) - Op = BinaryOperator::CreateAnd(ICIP, ICIQ); - else - Op = BinaryOperator::CreateOr(ICIP, ICIQ); - return Op; + if (BO0) { + // Transform A & (L - 1) `ult` L --> L != 0 + auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes()); + auto BitwiseAnd = + m_CombineOr(m_And(m_Value(), LSubOne), m_And(LSubOne, m_Value())); + + if (match(BO0, BitwiseAnd) && I.getPredicate() == ICmpInst::ICMP_ULT) { + auto *Zero = Constant::getNullValue(BO0->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero); } - break; } - case Instruction::Mul: { // (icmp pred (mul X, Val), CI) - ConstantInt *Val = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!Val) break; + return nullptr; +} - // If this is a signed comparison to 0 and the mul is sign preserving, - // use the mul LHS operand instead. - ICmpInst::Predicate pred = ICI.getPredicate(); - if (isSignTest(pred, RHS) && !Val->isZero() && - cast<BinaryOperator>(LHSI)->hasNoSignedWrap()) - return new ICmpInst(Val->isNegative() ? - ICmpInst::getSwappedPredicate(pred) : pred, - LHSI->getOperand(0), - Constant::getNullValue(RHS->getType())); +/// Fold icmp Pred min|max(X, Y), X. +static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op0 = Cmp.getOperand(0); + Value *X = Cmp.getOperand(1); + + // Canonicalize minimum or maximum operand to LHS of the icmp. 
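[Editor's note] The "a * Cst icmp eq/ne b * Cst" mask rewrite above works because Cst = odd * 2^k, the odd factor is invertible modulo 2^width, and so only the k low bits of information are lost by the multiply. A brute-force check on i8 (a sketch, not part of the patch; Cst = 0 and 1 excluded, as in the code):

  #include <cassert>

  int main() {
    const unsigned Mask8 = 0xff;
    for (unsigned Cst = 2; Cst <= Mask8; ++Cst) {
      unsigned TZ = 0;
      while (((Cst >> TZ) & 1) == 0)                  // count-trailing-zeros(Cst)
        ++TZ;
      unsigned Mask = Mask8 >> TZ;                    // -1 >> ctz(Cst), as in the comment
      for (unsigned A = 0; A <= Mask8; ++A)
        for (unsigned B = 0; B <= Mask8; ++B) {
          bool Original = ((A * Cst) & Mask8) == ((B * Cst) & Mask8);
          bool Folded = (A & Mask) == (B & Mask);
          assert(Original == Folded);
        }
    }
    return 0;
  }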
+ if (match(X, m_c_SMin(m_Specific(Op0), m_Value())) || + match(X, m_c_SMax(m_Specific(Op0), m_Value())) || + match(X, m_c_UMin(m_Specific(Op0), m_Value())) || + match(X, m_c_UMax(m_Specific(Op0), m_Value()))) { + std::swap(Op0, X); + Pred = Cmp.getSwappedPredicate(); + } - break; + Value *Y; + if (match(Op0, m_c_SMin(m_Specific(X), m_Value(Y)))) { + // smin(X, Y) == X --> X s<= Y + // smin(X, Y) s>= X --> X s<= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SGE) + return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); + + // smin(X, Y) != X --> X s> Y + // smin(X, Y) s< X --> X s> Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SLT) + return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); + + // These cases should be handled in InstSimplify: + // smin(X, Y) s<= X --> true + // smin(X, Y) s> X --> false + return nullptr; } - case Instruction::Shl: { // (icmp pred (shl X, ShAmt), CI) - uint32_t TypeBits = RHSV.getBitWidth(); - ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!ShAmt) { - Value *X; - // (1 << X) pred P2 -> X pred Log2(P2) - if (match(LHSI, m_Shl(m_One(), m_Value(X)))) { - bool RHSVIsPowerOf2 = RHSV.isPowerOf2(); - ICmpInst::Predicate Pred = ICI.getPredicate(); - if (ICI.isUnsigned()) { - if (!RHSVIsPowerOf2) { - // (1 << X) < 30 -> X <= 4 - // (1 << X) <= 30 -> X <= 4 - // (1 << X) >= 30 -> X > 4 - // (1 << X) > 30 -> X > 4 - if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_ULE; - else if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_UGT; - } - unsigned RHSLog2 = RHSV.logBase2(); - - // (1 << X) >= 2147483648 -> X >= 31 -> X == 31 - // (1 << X) < 2147483648 -> X < 31 -> X != 31 - if (RHSLog2 == TypeBits-1) { - if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_NE; - } + if (match(Op0, m_c_SMax(m_Specific(X), m_Value(Y)))) { + // smax(X, Y) == X --> X s>= Y + // smax(X, Y) s<= X --> X s>= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SLE) + return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); - return new ICmpInst(Pred, X, - ConstantInt::get(RHS->getType(), RHSLog2)); - } else if (ICI.isSigned()) { - if (RHSV.isAllOnesValue()) { - // (1 << X) <= -1 -> X == 31 - if (Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - - // (1 << X) > -1 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - } else if (!RHSV) { - // (1 << X) < 0 -> X == 31 - // (1 << X) <= 0 -> X == 31 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - - // (1 << X) >= 0 -> X != 31 - // (1 << X) > 0 -> X != 31 - if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(RHS->getType(), TypeBits-1)); - } - } else if (ICI.isEquality()) { - if (RHSVIsPowerOf2) - return new ICmpInst( - Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2())); - } - } - break; - } + // smax(X, Y) != X --> X s< Y + // smax(X, Y) s> X --> X s< Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); - // Check that the shift amount is in range. If not, don't perform - // undefined shifts. When the shift is visited it will be - // simplified. 
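The smin/smax/umin/umax folds in foldICmpWithMinMax (started above and continued below) all reduce to the same order-theoretic identities. A standalone check with std::min/std::max over signed 8-bit values; the unsigned variants are analogous:

#include <algorithm>
#include <cassert>

// smin(X, Y) == X  <=>  X s<= Y      smax(X, Y) == X  <=>  X s>= Y
// smin(X, Y) != X  <=>  X s>  Y      smax(X, Y) != X  <=>  X s<  Y
int main() {
  for (int x = -128; x <= 127; ++x)
    for (int y = -128; y <= 127; ++y) {
      assert((std::min(x, y) == x) == (x <= y));
      assert((std::min(x, y) != x) == (x > y));
      assert((std::max(x, y) == x) == (x >= y));
      assert((std::max(x, y) != x) == (x < y));
    }
  return 0;
}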
- if (ShAmt->uge(TypeBits)) - break; + // These cases should be handled in InstSimplify: + // smax(X, Y) s>= X --> true + // smax(X, Y) s< X --> false + return nullptr; + } - if (ICI.isEquality()) { - // If we are comparing against bits always shifted out, the - // comparison cannot succeed. - Constant *Comp = - ConstantExpr::getShl(ConstantExpr::getLShr(RHS, ShAmt), - ShAmt); - if (Comp != RHS) {// Comparing against a bit that we know is zero. - bool IsICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - Constant *Cst = Builder->getInt1(IsICMP_NE); - return replaceInstUsesWith(ICI, Cst); - } + if (match(Op0, m_c_UMin(m_Specific(X), m_Value(Y)))) { + // umin(X, Y) == X --> X u<= Y + // umin(X, Y) u>= X --> X u<= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_UGE) + return new ICmpInst(ICmpInst::ICMP_ULE, X, Y); - // If the shift is NUW, then it is just shifting out zeros, no need for an - // AND. - if (cast<BinaryOperator>(LHSI)->hasNoUnsignedWrap()) - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getLShr(RHS, ShAmt)); - - // If the shift is NSW and we compare to 0, then it is just shifting out - // sign bits, no need for an AND either. - if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && RHSV == 0) - return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), - ConstantExpr::getLShr(RHS, ShAmt)); - - if (LHSI->hasOneUse()) { - // Otherwise strength reduce the shift into an and. - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); - Constant *Mask = Builder->getInt(APInt::getLowBitsSet(TypeBits, - TypeBits - ShAmtVal)); - - Value *And = - Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask"); - return new ICmpInst(ICI.getPredicate(), And, - ConstantExpr::getLShr(RHS, ShAmt)); - } - } + // umin(X, Y) != X --> X u> Y + // umin(X, Y) u< X --> X u> Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT) + return new ICmpInst(ICmpInst::ICMP_UGT, X, Y); - // If this is a signed comparison to 0 and the shift is sign preserving, - // use the shift LHS operand instead. - ICmpInst::Predicate pred = ICI.getPredicate(); - if (isSignTest(pred, RHS) && - cast<BinaryOperator>(LHSI)->hasNoSignedWrap()) - return new ICmpInst(pred, - LHSI->getOperand(0), - Constant::getNullValue(RHS->getType())); - - // Otherwise, if this is a comparison of the sign bit, simplify to and/test. - bool TrueIfSigned = false; - if (LHSI->hasOneUse() && - isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) { - // (X << 31) <s 0 --> (X&1) != 0 - Constant *Mask = ConstantInt::get(LHSI->getOperand(0)->getType(), - APInt::getOneBitSet(TypeBits, - TypeBits-ShAmt->getZExtValue()-1)); - Value *And = - Builder->CreateAnd(LHSI->getOperand(0), Mask, LHSI->getName()+".mask"); - return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, - And, Constant::getNullValue(And->getType())); - } + // These cases should be handled in InstSimplify: + // umin(X, Y) u<= X --> true + // umin(X, Y) u> X --> false + return nullptr; + } - // Transform (icmp pred iM (shl iM %v, N), CI) - // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (CI>>N)) - // Transform the shl to a trunc if (trunc (CI>>N)) has no loss and M-N. - // This enables to get rid of the shift in favor of a trunc which can be - // free on the target. It has the additional benefit of comparing to a - // smaller constant, which will be target friendly. 
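The removed block above first folds an equality compare of a left shift to a constant outright when the constant has nonzero bits in the positions the shift clears, and otherwise strength-reduces the shift into a masked compare. A standalone check of that reduction, with an 8-bit width and a sample shift amount of 3 (both assumptions for the example):

#include <cassert>
#include <cstdint>

// (X << S) == C  <=>  (X & LowMask) == (C >> S), when the low S bits of C
// are zero; the shift discards the top S bits of X, so only 8 - S bits matter.
int main() {
  const unsigned S = 3;
  const uint8_t LowMask = 0xFF >> S;            // low 8 - S bits set
  for (unsigned C = 0; C < 256; C += 1u << S)   // constants with low S bits clear
    for (unsigned X = 0; X < 256; ++X) {
      bool shlEq = (uint8_t)(X << S) == C;
      bool maskEq = (X & LowMask) == (C >> S);
      assert(shlEq == maskEq);
    }
  return 0;
}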
- unsigned Amt = ShAmt->getLimitedValue(TypeBits-1); - if (LHSI->hasOneUse() && - Amt != 0 && RHSV.countTrailingZeros() >= Amt) { - Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt); - Constant *NCI = ConstantExpr::getTrunc( - ConstantExpr::getAShr(RHS, - ConstantInt::get(RHS->getType(), Amt)), - NTy); - return new ICmpInst(ICI.getPredicate(), - Builder->CreateTrunc(LHSI->getOperand(0), NTy), - NCI); - } + if (match(Op0, m_c_UMax(m_Specific(X), m_Value(Y)))) { + // umax(X, Y) == X --> X u>= Y + // umax(X, Y) u<= X --> X u>= Y + if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_ULE) + return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); - break; + // umax(X, Y) != X --> X u< Y + // umax(X, Y) u> X --> X u< Y + if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); + + // These cases should be handled in InstSimplify: + // umax(X, Y) u>= X --> true + // umax(X, Y) u< X --> false + return nullptr; } - case Instruction::LShr: // (icmp pred (shr X, ShAmt), CI) - case Instruction::AShr: { - // Handle equality comparisons of shift-by-constant. - BinaryOperator *BO = cast<BinaryOperator>(LHSI); - if (ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1))) { - if (Instruction *Res = FoldICmpShrCst(ICI, BO, ShAmt)) - return Res; - } + return nullptr; +} + +Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { + if (!I.isEquality()) + return nullptr; - // Handle exact shr's. - if (ICI.isEquality() && BO->isExact() && BO->hasOneUse()) { - if (RHSV.isMinValue()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), RHS); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *A, *B, *C, *D; + if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { + if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 + Value *OtherVal = A == Op1 ? 
B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); } - break; - } - case Instruction::UDiv: - if (ConstantInt *DivLHS = dyn_cast<ConstantInt>(LHSI->getOperand(0))) { - Value *X = LHSI->getOperand(1); - const APInt &C1 = RHS->getValue(); - const APInt &C2 = DivLHS->getValue(); - assert(C2 != 0 && "udiv 0, X should have been simplified already."); - // (icmp ugt (udiv C2, X), C1) -> (icmp ule X, C2/(C1+1)) - if (ICI.getPredicate() == ICmpInst::ICMP_UGT) { - assert(!C1.isMaxValue() && - "icmp ugt X, UINT_MAX should have been simplified already."); - return new ICmpInst(ICmpInst::ICMP_ULE, X, - ConstantInt::get(X->getType(), C2.udiv(C1 + 1))); - } - // (icmp ult (udiv C2, X), C1) -> (icmp ugt X, C2/C1) - if (ICI.getPredicate() == ICmpInst::ICMP_ULT) { - assert(C1 != 0 && "icmp ult X, 0 should have been simplified already."); - return new ICmpInst(ICmpInst::ICMP_UGT, X, - ConstantInt::get(X->getType(), C2.udiv(C1))); + if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { + // A^c1 == C^c2 --> A == C^(c1^c2) + ConstantInt *C1, *C2; + if (match(B, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2)) && + Op1->hasOneUse()) { + Constant *NC = Builder->getInt(C1->getValue() ^ C2->getValue()); + Value *Xor = Builder->CreateXor(C, NC); + return new ICmpInst(I.getPredicate(), A, Xor); } + + // A^B == A^D -> B == D + if (A == C) + return new ICmpInst(I.getPredicate(), B, D); + if (A == D) + return new ICmpInst(I.getPredicate(), B, C); + if (B == C) + return new ICmpInst(I.getPredicate(), A, D); + if (B == D) + return new ICmpInst(I.getPredicate(), A, C); } - // fall-through - case Instruction::SDiv: - // Fold: icmp pred ([us]div X, C1), C2 -> range test - // Fold this div into the comparison, producing a range check. - // Determine, based on the divide type, what the range is being - // checked. If there is an overflow on the low or high side, remember - // it, otherwise compute the range [low, hi) bounding the new value. - // See: InsertRangeTest above for the kinds of replacements possible. - if (ConstantInt *DivRHS = dyn_cast<ConstantInt>(LHSI->getOperand(1))) - if (Instruction *R = FoldICmpDivCst(ICI, cast<BinaryOperator>(LHSI), - DivRHS)) - return R; - break; + } - case Instruction::Sub: { - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHSI->getOperand(0)); - if (!LHSC) break; - const APInt &LHSV = LHSC->getValue(); - - // C1-X <u C2 -> (X|(C2-1)) == C1 - // iff C1 & (C2-1) == C2-1 - // C2 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == (RHSV - 1)) - return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateOr(LHSI->getOperand(1), RHSV - 1), - LHSC); - - // C1-X >u C2 -> (X|C2) != C1 - // iff C1 & C2 == C2 - // C2+1 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == RHSV) - return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateOr(LHSI->getOperand(1), RHSV), LHSC); - break; + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) { + // A == (A^B) -> B == 0 + Value *OtherVal = A == Op0 ? 
B : A; + return new ICmpInst(I.getPredicate(), OtherVal, + Constant::getNullValue(A->getType())); } - case Instruction::Add: - // Fold: icmp pred (add X, C1), C2 - if (!ICI.isEquality()) { - ConstantInt *LHSC = dyn_cast<ConstantInt>(LHSI->getOperand(1)); - if (!LHSC) break; - const APInt &LHSV = LHSC->getValue(); + // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 + if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && + match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { + Value *X = nullptr, *Y = nullptr, *Z = nullptr; + + if (A == C) { + X = B; + Y = D; + Z = A; + } else if (A == D) { + X = B; + Y = C; + Z = A; + } else if (B == C) { + X = A; + Y = D; + Z = B; + } else if (B == D) { + X = A; + Y = C; + Z = B; + } - ConstantRange CR = ICI.makeConstantRange(ICI.getPredicate(), RHSV) - .subtract(LHSV); + if (X) { // Build (X^Y) & Z + Op1 = Builder->CreateXor(X, Y); + Op1 = Builder->CreateAnd(Op1, Z); + I.setOperand(0, Op1); + I.setOperand(1, Constant::getNullValue(Op1->getType())); + return &I; + } + } - if (ICI.isSigned()) { - if (CR.getLower().isSignBit()) { - return new ICmpInst(ICmpInst::ICMP_SLT, LHSI->getOperand(0), - Builder->getInt(CR.getUpper())); - } else if (CR.getUpper().isSignBit()) { - return new ICmpInst(ICmpInst::ICMP_SGE, LHSI->getOperand(0), - Builder->getInt(CR.getLower())); - } - } else { - if (CR.getLower().isMinValue()) { - return new ICmpInst(ICmpInst::ICMP_ULT, LHSI->getOperand(0), - Builder->getInt(CR.getUpper())); - } else if (CR.getUpper().isMinValue()) { - return new ICmpInst(ICmpInst::ICMP_UGE, LHSI->getOperand(0), - Builder->getInt(CR.getLower())); - } - } + // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) + // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) + ConstantInt *Cst1; + if ((Op0->hasOneUse() && match(Op0, m_ZExt(m_Value(A))) && + match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) || + (Op1->hasOneUse() && match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) && + match(Op1, m_ZExt(m_Value(A))))) { + APInt Pow2 = Cst1->getValue() + 1; + if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && + Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) + return new ICmpInst(I.getPredicate(), A, + Builder->CreateTrunc(B, A->getType())); + } - // X-C1 <u C2 -> (X & -C2) == C1 - // iff C1 & (C2-1) == 0 - // C2 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_ULT && LHSI->hasOneUse() && - RHSV.isPowerOf2() && (LHSV & (RHSV - 1)) == 0) - return new ICmpInst(ICmpInst::ICMP_EQ, - Builder->CreateAnd(LHSI->getOperand(0), -RHSV), - ConstantExpr::getNeg(LHSC)); - - // X-C1 >u C2 -> (X & ~C2) != C1 - // iff C1 & C2 == 0 - // C2+1 is a power of 2 - if (ICI.getPredicate() == ICmpInst::ICMP_UGT && LHSI->hasOneUse() && - (RHSV + 1).isPowerOf2() && (LHSV & RHSV) == 0) - return new ICmpInst(ICmpInst::ICMP_NE, - Builder->CreateAnd(LHSI->getOperand(0), ~RHSV), - ConstantExpr::getNeg(LHSC)); + // (A >> C) == (B >> C) --> (A^B) u< (1 << C) + // For lshr and ashr pairs. + if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_LShr(m_Value(B), m_Specific(Cst1))))) || + (match(Op0, m_OneUse(m_AShr(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_AShr(m_Value(B), m_Specific(Cst1)))))) { + unsigned TypeBits = Cst1->getBitWidth(); + unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); + if (ShAmt < TypeBits && ShAmt != 0) { + ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_NE + ? 
ICmpInst::ICMP_UGE + : ICmpInst::ICMP_ULT; + Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); + APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); + return new ICmpInst(Pred, Xor, Builder->getInt(CmpVal)); } - break; } - // Simplify icmp_eq and icmp_ne instructions with integer constant RHS. - if (ICI.isEquality()) { - bool isICMP_NE = ICI.getPredicate() == ICmpInst::ICMP_NE; - - // If the first operand is (add|sub|and|or|xor|rem) with a constant, and - // the second operand is a constant, simplify a bit. - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(LHSI)) { - switch (BO->getOpcode()) { - case Instruction::SRem: - // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. - if (RHSV == 0 && isa<ConstantInt>(BO->getOperand(1)) &&BO->hasOneUse()){ - const APInt &V = cast<ConstantInt>(BO->getOperand(1))->getValue(); - if (V.sgt(1) && V.isPowerOf2()) { - Value *NewRem = - Builder->CreateURem(BO->getOperand(0), BO->getOperand(1), - BO->getName()); - return new ICmpInst(ICI.getPredicate(), NewRem, - Constant::getNullValue(BO->getType())); - } - } - break; - case Instruction::Add: - // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. - if (ConstantInt *BOp1C = dyn_cast<ConstantInt>(BO->getOperand(1))) { - if (BO->hasOneUse()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - ConstantExpr::getSub(RHS, BOp1C)); - } else if (RHSV == 0) { - // Replace ((add A, B) != 0) with (A != -B) if A or B is - // efficiently invertible, or if the add has just this one use. - Value *BOp0 = BO->getOperand(0), *BOp1 = BO->getOperand(1); - - if (Value *NegVal = dyn_castNegVal(BOp1)) - return new ICmpInst(ICI.getPredicate(), BOp0, NegVal); - if (Value *NegVal = dyn_castNegVal(BOp0)) - return new ICmpInst(ICI.getPredicate(), NegVal, BOp1); - if (BO->hasOneUse()) { - Value *Neg = Builder->CreateNeg(BOp1); - Neg->takeName(BO); - return new ICmpInst(ICI.getPredicate(), BOp0, Neg); - } - } - break; - case Instruction::Xor: - if (BO->hasOneUse()) { - if (Constant *BOC = dyn_cast<Constant>(BO->getOperand(1))) { - // For the xor case, we can xor two constants together, eliminating - // the explicit xor. - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - ConstantExpr::getXor(RHS, BOC)); - } else if (RHSV == 0) { - // Replace ((xor A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); - } - } - break; - case Instruction::Sub: - if (BO->hasOneUse()) { - if (ConstantInt *BOp0C = dyn_cast<ConstantInt>(BO->getOperand(0))) { - // Replace ((sub A, B) != C) with (B != A-C) if A & C are constants. - return new ICmpInst(ICI.getPredicate(), BO->getOperand(1), - ConstantExpr::getSub(BOp0C, RHS)); - } else if (RHSV == 0) { - // Replace ((sub A, B) != 0) with (A != B) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - BO->getOperand(1)); - } - } - break; - case Instruction::Or: - // If bits are being or'd in that are not present in the constant we - // are comparing against, then the comparison could never succeed! - if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { - Constant *NotCI = ConstantExpr::getNot(RHS); - if (!ConstantExpr::getAnd(BOC, NotCI)->isNullValue()) - return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); - - // Comparing if all bits outside of a constant mask are set? - // Replace (X | C) == -1 with (X & ~C) == ~C. - // This removes the -1 constant. 
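The fold completed at the top of this chunk rewrites an equality compare of two right shifts by the same constant into an unsigned compare of the XOR of the operands: the shifted values agree exactly when the operands agree on every bit above the shift amount. A standalone brute-force check at 8 bits with a sample shift of 3; the same condition governs the ashr pairing:

#include <cassert>
#include <cstdint>

// (A >> C) == (B >> C)  <=>  (A ^ B) u< (1 << C): equality of the shifted
// values means A and B differ only in their low C bits.
int main() {
  const unsigned C = 3;
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B) {
      bool shrEq = (A >> C) == (B >> C);
      bool xorUlt = (A ^ B) < (1u << C);
      assert(shrEq == xorUlt);
    }
  return 0;
}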
- if (BO->hasOneUse() && RHS->isAllOnesValue()) { - Constant *NotBOC = ConstantExpr::getNot(BOC); - Value *And = Builder->CreateAnd(BO->getOperand(0), NotBOC); - return new ICmpInst(ICI.getPredicate(), And, NotBOC); - } - } - break; + // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0 + if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) && + match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) { + unsigned TypeBits = Cst1->getBitWidth(); + unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); + if (ShAmt < TypeBits && ShAmt != 0) { + Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); + APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt); + Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal), + I.getName() + ".mask"); + return new ICmpInst(I.getPredicate(), And, + Constant::getNullValue(Cst1->getType())); + } + } - case Instruction::And: - if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { - // If bits are being compared against that are and'd out, then the - // comparison can never succeed! - if ((RHSV & ~BOC->getValue()) != 0) - return replaceInstUsesWith(ICI, Builder->getInt1(isICMP_NE)); - - // If we have ((X & C) == C), turn it into ((X & C) != 0). - if (RHS == BOC && RHSV.isPowerOf2()) - return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : - ICmpInst::ICMP_NE, LHSI, - Constant::getNullValue(RHS->getType())); - - // Don't perform the following transforms if the AND has multiple uses - if (!BO->hasOneUse()) - break; + // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to + // "icmp (and X, mask), cst" + uint64_t ShAmt = 0; + if (Op0->hasOneUse() && + match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), m_ConstantInt(ShAmt))))) && + match(Op1, m_ConstantInt(Cst1)) && + // Only do this when A has multiple uses. This is most important to do + // when it exposes other optimizations. + !A->hasOneUse()) { + unsigned ASize = cast<IntegerType>(A->getType())->getPrimitiveSizeInBits(); + + if (ShAmt < ASize) { + APInt MaskV = + APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits()); + MaskV <<= ShAmt; - // Replace (and X, (1 << size(X)-1) != 0) with x s< 0 - if (BOC->getValue().isSignBit()) { - Value *X = BO->getOperand(0); - Constant *Zero = Constant::getNullValue(X->getType()); - ICmpInst::Predicate pred = isICMP_NE ? - ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE; - return new ICmpInst(pred, X, Zero); - } + APInt CmpV = Cst1->getValue().zext(ASize); + CmpV <<= ShAmt; - // ((X & ~7) == 0) --> X < 8 - if (RHSV == 0 && isHighOnes(BOC)) { - Value *X = BO->getOperand(0); - Constant *NegX = ConstantExpr::getNeg(BOC); - ICmpInst::Predicate pred = isICMP_NE ? - ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; - return new ICmpInst(pred, X, NegX); - } - } - break; - case Instruction::Mul: - if (RHSV == 0 && BO->hasNoSignedWrap()) { - if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) { - // The trivial case (mul X, 0) is handled by InstSimplify - // General case : (mul X, C) != 0 iff X != 0 - // (mul X, C) == 0 iff X == 0 - if (!BOC->isZero()) - return new ICmpInst(ICI.getPredicate(), BO->getOperand(0), - Constant::getNullValue(RHS->getType())); - } - } - break; - default: break; - } - } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(LHSI)) { - // Handle icmp {eq|ne} <intrinsic>, intcst. 
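The removed And-with-constant cases above include two sign/mask folds: a compare against a mask of contiguous high ones becomes an unsigned bound, and a test of the sign-bit mask becomes a signed compare with zero. A standalone check at 8 bits (the usual two's-complement narrowing is assumed):

#include <cassert>
#include <cstdint>

//   (X & ~7)   == 0  <=>  X u< 8     (mask of contiguous high ones)
//   (X & 0x80) != 0  <=>  X s< 0     (mask is exactly the sign bit)
int main() {
  for (unsigned X = 0; X < 256; ++X) {
    assert(((X & 0xF8u) == 0) == (X < 8u));          // ~7 on 8 bits is 0xF8
    assert(((X & 0x80u) != 0) == ((int8_t)X < 0));
  }
  return 0;
}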
- switch (II->getIntrinsicID()) { - case Intrinsic::bswap: - Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, Builder->getInt(RHSV.byteSwap())); - return &ICI; - case Intrinsic::ctlz: - case Intrinsic::cttz: - // ctz(A) == bitwidth(a) -> A == 0 and likewise for != - if (RHSV == RHS->getType()->getBitWidth()) { - Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, ConstantInt::get(RHS->getType(), 0)); - return &ICI; - } - break; - case Intrinsic::ctpop: - // popcount(A) == 0 -> A == 0 and likewise for != - if (RHS->isZero()) { - Worklist.Add(II); - ICI.setOperand(0, II->getArgOperand(0)); - ICI.setOperand(1, RHS); - return &ICI; - } - break; - default: - break; - } + Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV)); + return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV)); } } + return nullptr; } /// Handle icmp (cast x to y), (cast/cst). We only handle extending casts so /// far. -Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICmp) { +Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { const CastInst *LHSCI = cast<CastInst>(ICmp.getOperand(0)); Value *LHSCIOp = LHSCI->getOperand(0); Type *SrcTy = LHSCIOp->getType(); @@ -2485,92 +3384,6 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICmp) { return BinaryOperator::CreateNot(Result); } -/// The caller has matched a pattern of the form: -/// I = icmp ugt (add (add A, B), CI2), CI1 -/// If this is of the form: -/// sum = a + b -/// if (sum+128 >u 255) -/// Then replace it with llvm.sadd.with.overflow.i8. -/// -static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, - ConstantInt *CI2, ConstantInt *CI1, - InstCombiner &IC) { - // The transformation we're trying to do here is to transform this into an - // llvm.sadd.with.overflow. To do this, we have to replace the original add - // with a narrower add, and discard the add-with-constant that is part of the - // range check (if we can't eliminate it, this isn't profitable). - - // In order to eliminate the add-with-constant, the compare can be its only - // use. - Instruction *AddWithCst = cast<Instruction>(I.getOperand(0)); - if (!AddWithCst->hasOneUse()) return nullptr; - - // If CI2 is 2^7, 2^15, 2^31, then it might be an sadd.with.overflow. - if (!CI2->getValue().isPowerOf2()) return nullptr; - unsigned NewWidth = CI2->getValue().countTrailingZeros(); - if (NewWidth != 7 && NewWidth != 15 && NewWidth != 31) return nullptr; - - // The width of the new add formed is 1 more than the bias. - ++NewWidth; - - // Check to see that CI1 is an all-ones value with NewWidth bits. - if (CI1->getBitWidth() == NewWidth || - CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth)) - return nullptr; - - // This is only really a signed overflow check if the inputs have been - // sign-extended; check for that condition. For example, if CI2 is 2^31 and - // the operands of the add are 64 bits wide, we need at least 33 sign bits. - unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || - IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) - return nullptr; - - // In order to replace the original add with a narrower - // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant - // and truncates that discard the high bits of the add. Verify that this is - // the case. 
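The removed ProcessUGT_ADDCST_ADD helper (its remaining body follows below) recognizes the wide-add-then-range-check idiom and rewrites it to llvm.sadd.with.overflow. A standalone demonstration of the arithmetic it relies on, using the i8 flavour named in its comment (operands sign-extended and added in a wider type):

#include <cassert>

// sum = a + b computed in a wide type; (sum + 128) u> 255 is true exactly
// when sum falls outside [-128, 127], i.e. exactly when the i8 signed
// addition would overflow.
int main() {
  for (int a = -128; a <= 127; ++a)
    for (int b = -128; b <= 127; ++b) {
      int sum = a + b;                            // wide addition
      bool idiom = (unsigned)(sum + 128) > 255u;  // the matched compare
      bool overflows = sum < -128 || sum > 127;   // i8 signed overflow
      assert(idiom == overflows);
    }
  return 0;
}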
- Instruction *OrigAdd = cast<Instruction>(AddWithCst->getOperand(0)); - for (User *U : OrigAdd->users()) { - if (U == AddWithCst) continue; - - // Only accept truncates for now. We would really like a nice recursive - // predicate like SimplifyDemandedBits, but which goes downwards the use-def - // chain to see which bits of a value are actually demanded. If the - // original add had another add which was then immediately truncated, we - // could still do the transformation. - TruncInst *TI = dyn_cast<TruncInst>(U); - if (!TI || TI->getType()->getPrimitiveSizeInBits() > NewWidth) - return nullptr; - } - - // If the pattern matches, truncate the inputs to the narrower type and - // use the sadd_with_overflow intrinsic to efficiently compute both the - // result and the overflow bit. - Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth); - Value *F = Intrinsic::getDeclaration(I.getModule(), - Intrinsic::sadd_with_overflow, NewType); - - InstCombiner::BuilderTy *Builder = IC.Builder; - - // Put the new code above the original add, in case there are any uses of the - // add between the add and the compare. - Builder->SetInsertPoint(OrigAdd); - - Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc"); - Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc"); - CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd"); - Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result"); - Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType()); - - // The inner add was the result of the narrow add, zero extended to the - // wider type. Replace it with the result computed by the intrinsic. - IC.replaceInstUsesWith(*OrigAdd, ZExt); - - // The original icmp gets replaced with the overflow value. - return ExtractValueInst::Create(Call, 1, "sadd.overflow"); -} - bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, Value *RHS, Instruction &OrigI, Value *&Result, Constant *&Overflow) { @@ -2603,8 +3416,10 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, if (OR == OverflowResult::AlwaysOverflows) return SetResult(Builder->CreateAdd(LHS, RHS), Builder->getTrue(), true); + + // Fall through uadd into sadd + LLVM_FALLTHROUGH; } - // FALL THROUGH uadd into sadd case OCF_SIGNED_ADD: { // X + 0 -> {X, false} if (match(RHS, m_Zero())) @@ -2644,7 +3459,8 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, true); if (OR == OverflowResult::AlwaysOverflows) return SetResult(Builder->CreateMul(LHS, RHS), Builder->getTrue(), true); - } // FALL THROUGH + LLVM_FALLTHROUGH; + } case OCF_SIGNED_MUL: // X * undef -> undef if (isa<UndefValue>(RHS)) @@ -2682,7 +3498,7 @@ bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS, /// \param OtherVal The other argument of compare instruction. /// \returns Instruction which must replace the compare instruction, NULL if no /// replacement required. -static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, +static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, Value *OtherVal, InstCombiner &IC) { // Don't bother doing this transformation for pointers, don't do it for // vectors. @@ -2906,8 +3722,8 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal, /// When performing a comparison against a constant, it is possible that not all /// the bits in the LHS are demanded. This helper method computes the mask that /// IS demanded. 
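The renamed processUMulZExtIdiom handles compares built from a multiply of two zero-extended operands, the usual way front ends test for unsigned multiplication overflow; the exact patterns it accepts are not visible in this hunk, so the following is only a sketch of the underlying idiom, scaled down to 8 -> 16 bits:

#include <cassert>
#include <cstdint>

// Zero-extend, multiply in a wider type, and compare the product against the
// narrow maximum: the compare is true exactly when the narrow unsigned
// multiplication would overflow.
int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b) {
      uint16_t wide = (uint16_t)(a * b);   // zext + wide mul (fits in 16 bits)
      bool idiom = wide > 0xFFu;           // icmp ugt wide, 0xFF
      bool overflows = a * b > 0xFFu;      // true 8-bit unsigned overflow
      assert(idiom == overflows);
    }
  return 0;
}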
-static APInt DemandedBitsLHSMask(ICmpInst &I, - unsigned BitWidth, bool isSignCheck) { +static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, + bool isSignCheck) { if (isSignCheck) return APInt::getSignBit(BitWidth); @@ -2981,7 +3797,7 @@ static bool swapMayExposeCSEOpportunities(const Value * Op0, } /// \brief Check that one use is in the same block as the definition and all -/// other uses are in blocks dominated by a given block +/// other uses are in blocks dominated by a given block. /// /// \param DI Definition /// \param UI Use @@ -2994,21 +3810,18 @@ bool InstCombiner::dominatesAllUses(const Instruction *DI, const Instruction *UI, const BasicBlock *DB) const { assert(DI && UI && "Instruction not defined\n"); - // ignore incomplete definitions + // Ignore incomplete definitions. if (!DI->getParent()) return false; - // DI and UI must be in the same block + // DI and UI must be in the same block. if (DI->getParent() != UI->getParent()) return false; - // Protect from self-referencing blocks + // Protect from self-referencing blocks. if (DI->getParent() == DB) return false; - // DominatorTree available? - if (!DT) - return false; for (const User *U : DI->users()) { auto *Usr = cast<Instruction>(U); - if (Usr != UI && !DT->dominates(DB, Usr->getParent())) + if (Usr != UI && !DT.dominates(DB, Usr->getParent())) return false; } return true; @@ -3067,8 +3880,7 @@ static bool isChainSelectCmpBranch(const SelectInst *SI) { /// are equal, the optimization can work only for EQ predicates. This is not a /// major restriction since a NE compare should be 'normalized' to an equal /// compare, which usually happens in the combiner and test case -/// select-cmp-br.ll -/// checks for it. +/// select-cmp-br.ll checks for it. bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, const unsigned SIOpd) { @@ -3076,7 +3888,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) { BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1); // The check for the unique predecessor is not the best that can be - // done. But it protects efficiently against cases like when SI's + // done. But it protects efficiently against cases like when SI's // home block has two successors, Succ and Succ1, and Succ1 predecessor // of Succ. Then SI can't be replaced by SIOpd because the use that gets // replaced can be reached on either path. So the uniqueness check @@ -3093,6 +3905,229 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI, return false; } +/// Try to fold the comparison based on range information we can get by checking +/// whether bits are known to be zero or one in the inputs. +Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = Op0->getType(); + ICmpInst::Predicate Pred = I.getPredicate(); + + // Get scalar or pointer size. + unsigned BitWidth = Ty->isIntOrIntVectorTy() + ? Ty->getScalarSizeInBits() + : DL.getTypeSizeInBits(Ty->getScalarType()); + + if (!BitWidth) + return nullptr; + + // If this is a normal comparison, it demands all bits. If it is a sign bit + // comparison, it only demands the sign bit. 
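The comment ending this chunk is the key to getDemandedBitsLHSMask: a sign-bit check demands nothing but the sign bit, so SimplifyDemandedBits is free to treat every lower bit of the LHS as don't-care. A small standalone demonstration at 8 bits (two's-complement narrowing assumed):

#include <cassert>
#include <cstdint>

// "x s< 0" depends only on the sign bit: replacing all lower bits with
// arbitrary values never changes the outcome of the test.
int main() {
  for (unsigned x = 0; x < 256; ++x)
    for (unsigned noise = 0; noise < 128; ++noise) {
      uint8_t rewritten = (uint8_t)((x & 0x80u) | noise);  // keep only x's sign bit
      assert(((int8_t)rewritten < 0) == ((int8_t)x < 0));
    }
  return 0;
}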
+ bool IsSignBit = false; + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + bool UnusedBit; + IsSignBit = isSignBitCheck(Pred, *CmpC, UnusedBit); + } + + APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); + APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); + + if (SimplifyDemandedBits(I.getOperandUse(0), + getDemandedBitsLHSMask(I, BitWidth, IsSignBit), + Op0KnownZero, Op0KnownOne, 0)) + return &I; + + if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth), + Op1KnownZero, Op1KnownOne, 0)) + return &I; + + // Given the known and unknown bits, compute a range that the LHS could be + // in. Compute the Min, Max and RHS values based on the known bits. For the + // EQ and NE we use unsigned values. + APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); + APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); + if (I.isSigned()) { + computeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, + Op0Max); + computeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, + Op1Max); + } else { + computeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, Op0Min, + Op0Max); + computeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, Op1Min, + Op1Max); + } + + // If Min and Max are known to be the same, then SimplifyDemandedBits + // figured out that the LHS is a constant. Constant fold this now, so that + // code below can assume that Min != Max. + if (!isa<Constant>(Op0) && Op0Min == Op0Max) + return new ICmpInst(Pred, ConstantInt::get(Op0->getType(), Op0Min), Op1); + if (!isa<Constant>(Op1) && Op1Min == Op1Max) + return new ICmpInst(Pred, Op0, ConstantInt::get(Op1->getType(), Op1Min)); + + // Based on the range information we know about the LHS, see if we can + // simplify this comparison. For example, (x&4) < 8 is always true. + switch (Pred) { + default: + llvm_unreachable("Unknown icmp opcode!"); + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: { + if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) { + return Pred == CmpInst::ICMP_EQ + ? replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())) + : replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + } + + // If all bits are known zero except for one, then we know at most one bit + // is set. If the comparison is against zero, then this is a check to see if + // *that* bit is set. + APInt Op0KnownZeroInverted = ~Op0KnownZero; + if (~Op1KnownZero == 0) { + // If the LHS is an AND with the same constant, look through it. + Value *LHS = nullptr; + const APInt *LHSC; + if (!match(Op0, m_And(m_Value(LHS), m_APInt(LHSC))) || + *LHSC != Op0KnownZeroInverted) + LHS = Op0; + + Value *X; + if (match(LHS, m_Shl(m_One(), m_Value(X)))) { + APInt ValToCheck = Op0KnownZeroInverted; + Type *XTy = X->getType(); + if (ValToCheck.isPowerOf2()) { + // ((1 << X) & 8) == 0 -> X != 3 + // ((1 << X) & 8) != 0 -> X == 3 + auto *CmpC = ConstantInt::get(XTy, ValToCheck.countTrailingZeros()); + auto NewPred = ICmpInst::getInversePredicate(Pred); + return new ICmpInst(NewPred, X, CmpC); + } else if ((++ValToCheck).isPowerOf2()) { + // ((1 << X) & 7) == 0 -> X >= 3 + // ((1 << X) & 7) != 0 -> X < 3 + auto *CmpC = ConstantInt::get(XTy, ValToCheck.countTrailingZeros()); + auto NewPred = + Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGE : CmpInst::ICMP_ULT; + return new ICmpInst(NewPred, X, CmpC); + } + } + + // Check if the LHS is 8 >>u x and the result is a power of 2 like 1. 
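The EQ/NE case above turns a masked test of a shifted one into a direct compare of the shift amount, in two shapes depending on whether the mask is a single power of two or a low-bit mask. A brute-force check of both shapes at 8 bits; oversized shift amounts make the shl undefined in IR, so only X u< 8 is exercised:

#include <cassert>

//   ((1 << X) & 8) == 0  <=>  X != 3      (mask is a power of two)
//   ((1 << X) & 7) == 0  <=>  X u>= 3     (mask is 2^k - 1)
int main() {
  for (unsigned X = 0; X < 8; ++X) {
    assert((((1u << X) & 8u) == 0) == (X != 3));
    assert((((1u << X) & 7u) == 0) == (X >= 3));
  }
  return 0;
}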
+ const APInt *CI; + if (Op0KnownZeroInverted == 1 && + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) { + // ((8 >>u X) & 1) == 0 -> X != 3 + // ((8 >>u X) & 1) != 0 -> X == 3 + unsigned CmpVal = CI->countTrailingZeros(); + auto NewPred = ICmpInst::getInversePredicate(Pred); + return new ICmpInst(NewPred, X, ConstantInt::get(X->getType(), CmpVal)); + } + } + break; + } + case ICmpInst::ICMP_ULT: { + if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A <u C -> A == C-1 if min(A)+1 == C + if (Op1Max == Op0Min + 1) { + Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1); + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1); + } + } + break; + } + case ICmpInst::ICMP_UGT: { + if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + + if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A >u C -> A == C+1 if max(a)-1 == C + if (*CmpC == Op0Max - 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + } + break; + } + case ICmpInst::ICMP_SLT: + if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Builder->getInt(CI->getValue() - 1)); + } + break; + case ICmpInst::ICMP_SGT: + if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + + if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { + if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Builder->getInt(CI->getValue() + 1)); + } + break; + case ICmpInst::ICMP_SGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); + if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_SLE: + assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); + if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, 
ConstantInt::getTrue(I.getType())); + if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_UGE: + assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); + if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + case ICmpInst::ICMP_ULE: + assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); + if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + break; + } + + // Turn a signed comparison into an unsigned one if both operands are known to + // have the same sign. + if (I.isSigned() && + ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || + (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) + return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); + + return nullptr; +} + /// If we have an icmp le or icmp ge instruction with a constant operand, turn /// it into the appropriate icmp lt or icmp gt instruction. This transform /// allows them to be folded in visitICmpInst. @@ -3131,6 +4166,7 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { if (isa<UndefValue>(Elt)) continue; + // Bail out if we can't determine if this constant is min/max or if we // know that this constant is min/max. auto *CI = dyn_cast<ConstantInt>(Elt); @@ -3167,7 +4203,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { } if (Value *V = - SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC, &I)) + SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, &TLI, &DT, &AC, &I)) return replaceInstUsesWith(I, V); // comparing -val or val with non-zero is the same as just comparing val @@ -3202,28 +4238,28 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { case ICmpInst::ICMP_UGT: std::swap(Op0, Op1); // Change icmp ugt -> icmp ult - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_ULT:{ // icmp ult i1 A, B -> ~A & B Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateAnd(Not, Op1); } case ICmpInst::ICMP_SGT: std::swap(Op0, Op1); // Change icmp sgt -> icmp slt - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLT: { // icmp slt i1 A, B -> A & ~B Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return BinaryOperator::CreateAnd(Not, Op0); } case ICmpInst::ICMP_UGE: std::swap(Op0, Op1); // Change icmp uge -> icmp ule - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_ULE: { // icmp ule i1 A, B -> ~A | B Value *Not = Builder->CreateNot(Op0, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op1); } case ICmpInst::ICMP_SGE: std::swap(Op0, Op1); // Change icmp sge -> icmp sle - // FALL THROUGH + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SLE: { // icmp sle i1 A, B -> A | ~B Value *Not = Builder->CreateNot(Op1, I.getName() + "tmp"); return BinaryOperator::CreateOr(Not, Op0); @@ -3234,372 +4270,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I)) return NewICmp; - unsigned BitWidth = 0; - if (Ty->isIntOrIntVectorTy()) - BitWidth = Ty->getScalarSizeInBits(); - else // Get pointer size. 
- BitWidth = DL.getTypeSizeInBits(Ty->getScalarType()); - - bool isSignBit = false; - - // See if we are doing a comparison with a constant. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - Value *A = nullptr, *B = nullptr; - - // Match the following pattern, which is a common idiom when writing - // overflow-safe integer arithmetic function. The source performs an - // addition in wider type, and explicitly checks for overflow using - // comparisons against INT_MIN and INT_MAX. Simplify this by using the - // sadd_with_overflow intrinsic. - // - // TODO: This could probably be generalized to handle other overflow-safe - // operations if we worked out the formulas to compute the appropriate - // magic constants. - // - // sum = a + b - // if (sum+128 >u 255) ... -> llvm.sadd.with.overflow.i8 - { - ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI - if (I.getPredicate() == ICmpInst::ICMP_UGT && - match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) - if (Instruction *Res = ProcessUGT_ADDCST_ADD(I, A, B, CI2, CI, *this)) - return Res; - } - - // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) - if (CI->isZero() && I.getPredicate() == ICmpInst::ICMP_SGT) - if (auto *SI = dyn_cast<SelectInst>(Op0)) { - SelectPatternResult SPR = matchSelectPattern(SI, A, B); - if (SPR.Flavor == SPF_SMIN) { - if (isKnownPositive(A, DL)) - return new ICmpInst(I.getPredicate(), B, CI); - if (isKnownPositive(B, DL)) - return new ICmpInst(I.getPredicate(), A, CI); - } - } - - - // The following transforms are only 'worth it' if the only user of the - // subtraction is the icmp. - if (Op0->hasOneUse()) { - // (icmp ne/eq (sub A B) 0) -> (icmp ne/eq A, B) - if (I.isEquality() && CI->isZero() && - match(Op0, m_Sub(m_Value(A), m_Value(B)))) - return new ICmpInst(I.getPredicate(), A, B); - - // (icmp sgt (sub nsw A B), -1) -> (icmp sge A, B) - if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isAllOnesValue() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SGE, A, B); - - // (icmp sgt (sub nsw A B), 0) -> (icmp sgt A, B) - if (I.getPredicate() == ICmpInst::ICMP_SGT && CI->isZero() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SGT, A, B); - - // (icmp slt (sub nsw A B), 0) -> (icmp slt A, B) - if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isZero() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SLT, A, B); - - // (icmp slt (sub nsw A B), 1) -> (icmp sle A, B) - if (I.getPredicate() == ICmpInst::ICMP_SLT && CI->isOne() && - match(Op0, m_NSWSub(m_Value(A), m_Value(B)))) - return new ICmpInst(ICmpInst::ICMP_SLE, A, B); - } - - if (I.isEquality()) { - ConstantInt *CI2; - if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) || - match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) { - // (icmp eq/ne (ashr/lshr const2, A), const1) - if (Instruction *Inst = FoldICmpCstShrCst(I, Op0, A, CI, CI2)) - return Inst; - } - if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) { - // (icmp eq/ne (shl const2, A), const1) - if (Instruction *Inst = FoldICmpCstShlCst(I, Op0, A, CI, CI2)) - return Inst; - } - } - - // If this comparison is a normal comparison, it demands all - // bits, if it is a sign bit comparison, it only demands the sign bit. - bool UnusedBit; - isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit); - - // Canonicalize icmp instructions based on dominating conditions. 
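The removed one-use folds above lean on the nsw flag: when A - B cannot overflow, the sign of the difference orders A against B directly. A standalone check over all 8-bit pairs, skipping the pairs whose subtraction would overflow (those would not carry nsw):

#include <cassert>

// With no signed overflow in d = a - b:
//   d s> -1 <=> a s>= b      d s> 0 <=> a s> b
//   d s<  0 <=> a s<  b      d s< 1 <=> a s<= b
int main() {
  for (int a = -128; a <= 127; ++a)
    for (int b = -128; b <= 127; ++b) {
      int d = a - b;
      if (d < -128 || d > 127)
        continue;                      // 8-bit sub would overflow; no nsw
      assert((d > -1) == (a >= b));
      assert((d > 0) == (a > b));
      assert((d < 0) == (a < b));
      assert((d < 1) == (a <= b));
    }
  return 0;
}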
- BasicBlock *Parent = I.getParent(); - BasicBlock *Dom = Parent->getSinglePredecessor(); - auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr; - ICmpInst::Predicate Pred; - BasicBlock *TrueBB, *FalseBB; - ConstantInt *CI2; - if (BI && match(BI, m_Br(m_ICmp(Pred, m_Specific(Op0), m_ConstantInt(CI2)), - TrueBB, FalseBB)) && - TrueBB != FalseBB) { - ConstantRange CR = ConstantRange::makeAllowedICmpRegion(I.getPredicate(), - CI->getValue()); - ConstantRange DominatingCR = - (Parent == TrueBB) - ? ConstantRange::makeExactICmpRegion(Pred, CI2->getValue()) - : ConstantRange::makeExactICmpRegion( - CmpInst::getInversePredicate(Pred), CI2->getValue()); - ConstantRange Intersection = DominatingCR.intersectWith(CR); - ConstantRange Difference = DominatingCR.difference(CR); - if (Intersection.isEmptySet()) - return replaceInstUsesWith(I, Builder->getFalse()); - if (Difference.isEmptySet()) - return replaceInstUsesWith(I, Builder->getTrue()); - // Canonicalizing a sign bit comparison that gets used in a branch, - // pessimizes codegen by generating branch on zero instruction instead - // of a test and branch. So we avoid canonicalizing in such situations - // because test and branch instruction has better branch displacement - // than compare and branch instruction. - if (!isBranchOnSignBitCheck(I, isSignBit) && !I.isEquality()) { - if (auto *AI = Intersection.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Builder->getInt(*AI)); - if (auto *AD = Difference.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Builder->getInt(*AD)); - } - } - } - - // See if we can fold the comparison based on range information we can get - // by checking whether bits are known to be zero or one in the input. - if (BitWidth != 0) { - APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0); - APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0); - - if (SimplifyDemandedBits(I.getOperandUse(0), - DemandedBitsLHSMask(I, BitWidth, isSignBit), - Op0KnownZero, Op0KnownOne, 0)) - return &I; - if (SimplifyDemandedBits(I.getOperandUse(1), - APInt::getAllOnesValue(BitWidth), Op1KnownZero, - Op1KnownOne, 0)) - return &I; - - // Given the known and unknown bits, compute a range that the LHS could be - // in. Compute the Min, Max and RHS values based on the known bits. For the - // EQ and NE we use unsigned values. - APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0); - APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0); - if (I.isSigned()) { - ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, - Op0Min, Op0Max); - ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, - Op1Min, Op1Max); - } else { - ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne, - Op0Min, Op0Max); - ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne, - Op1Min, Op1Max); - } - - // If Min and Max are known to be the same, then SimplifyDemandedBits - // figured out that the LHS is a constant. Just constant fold this now so - // that code below can assume that Min != Max. - if (!isa<Constant>(Op0) && Op0Min == Op0Max) - return new ICmpInst(I.getPredicate(), - ConstantInt::get(Op0->getType(), Op0Min), Op1); - if (!isa<Constant>(Op1) && Op1Min == Op1Max) - return new ICmpInst(I.getPredicate(), Op0, - ConstantInt::get(Op1->getType(), Op1Min)); - - // Based on the range information we know about the LHS, see if we can - // simplify this comparison. For example, (x&4) < 8 is always true. 
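The block removed above canonicalizes an icmp using the range implied by a dominating branch: an empty intersection of the two constant ranges folds the compare to false, an empty difference folds it to true, and a single-element intersection becomes an equality. A toy illustration of those three outcomes (the concrete bounds are made up for the example):

#include <cassert>

// Dominating edge guarantees x u< 10.
int main() {
  for (unsigned x = 0; x < 10; ++x) {
    assert(x < 20);                    // difference empty   -> fold to true
    assert(!(x > 50));                 // intersection empty -> fold to false
  }
  for (unsigned x = 0; x < 2; ++x)     // dominating edge: x u< 2
    assert((x >= 1) == (x == 1));      // single-element intersection -> icmp eq
  return 0;
}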
- switch (I.getPredicate()) { - default: llvm_unreachable("Unknown icmp opcode!"); - case ICmpInst::ICMP_EQ: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - // If all bits are known zero except for one, then we know at most one - // bit is set. If the comparison is against zero, then this is a check - // to see if *that* bit is set. - APInt Op0KnownZeroInverted = ~Op0KnownZero; - if (~Op1KnownZero == 0) { - // If the LHS is an AND with the same constant, look through it. - Value *LHS = nullptr; - ConstantInt *LHSC = nullptr; - if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || - LHSC->getValue() != Op0KnownZeroInverted) - LHS = Op0; - - // If the LHS is 1 << x, and we know the result is a power of 2 like 8, - // then turn "((1 << x)&8) == 0" into "x != 3". - // or turn "((1 << x)&7) == 0" into "x > 2". - Value *X = nullptr; - if (match(LHS, m_Shl(m_One(), m_Value(X)))) { - APInt ValToCheck = Op0KnownZeroInverted; - if (ValToCheck.isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(X->getType(), CmpVal)); - } else if ((++ValToCheck).isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros() - 1; - return new ICmpInst(ICmpInst::ICMP_UGT, X, - ConstantInt::get(X->getType(), CmpVal)); - } - } - - // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, - // then turn "((8 >>u x)&1) == 0" into "x != 3". - const APInt *CI; - if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) - return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(X->getType(), - CI->countTrailingZeros())); - } - break; - } - case ICmpInst::ICMP_NE: { - if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max)) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - - // If all bits are known zero except for one, then we know at most one - // bit is set. If the comparison is against zero, then this is a check - // to see if *that* bit is set. - APInt Op0KnownZeroInverted = ~Op0KnownZero; - if (~Op1KnownZero == 0) { - // If the LHS is an AND with the same constant, look through it. - Value *LHS = nullptr; - ConstantInt *LHSC = nullptr; - if (!match(Op0, m_And(m_Value(LHS), m_ConstantInt(LHSC))) || - LHSC->getValue() != Op0KnownZeroInverted) - LHS = Op0; - - // If the LHS is 1 << x, and we know the result is a power of 2 like 8, - // then turn "((1 << x)&8) != 0" into "x == 3". - // or turn "((1 << x)&7) != 0" into "x < 3". - Value *X = nullptr; - if (match(LHS, m_Shl(m_One(), m_Value(X)))) { - APInt ValToCheck = Op0KnownZeroInverted; - if (ValToCheck.isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(X->getType(), CmpVal)); - } else if ((++ValToCheck).isPowerOf2()) { - unsigned CmpVal = ValToCheck.countTrailingZeros(); - return new ICmpInst(ICmpInst::ICMP_ULT, X, - ConstantInt::get(X->getType(), CmpVal)); - } - } - - // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, - // then turn "((8 >>u x)&1) != 0" into "x == 3". 
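The ULT/UGT cases above also fold compares against the signed boundary constants into plain sign tests. A brute-force check of the 8-bit analogue of the i32 constants 2147483647 and 2147483648 (two's-complement narrowing assumed):

#include <cassert>
#include <cstdint>

//   x u> 127  <=>  x s< 0          x u< 128  <=>  x s> -1
int main() {
  for (unsigned x = 0; x < 256; ++x) {
    assert((x > 127u) == ((int8_t)x < 0));
    assert((x < 128u) == ((int8_t)x > -1));
  }
  return 0;
}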
- const APInt *CI; - if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) - return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(X->getType(), - CI->countTrailingZeros())); - } - break; - } - case ICmpInst::ICMP_ULT: - if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Max == Op0Min+1) // A <u C -> A == C-1 if min(A)+1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()-1)); - - // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear - if (CI->isMinValue(true)) - return new ICmpInst(ICmpInst::ICMP_SGT, Op0, - Constant::getAllOnesValue(Op0->getType())); - } - break; - case ICmpInst::ICMP_UGT: - if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Min == Op0Max-1) // A >u C -> A == C+1 if max(a)-1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()+1)); - - // (x >u 2147483647) -> (x <s 0) -> true if sign bit set - if (CI->isMaxValue(true)) - return new ICmpInst(ICmpInst::ICMP_SLT, Op0, - Constant::getNullValue(Op0->getType())); - } - break; - case ICmpInst::ICMP_SLT: - if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Max == Op0Min+1) // A <s C -> A == C-1 if min(A)+1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()-1)); - } - break; - case ICmpInst::ICMP_SGT: - if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - - if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) - return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Min == Op0Max-1) // A >s C -> A == C+1 if max(A)-1 == C - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder->getInt(CI->getValue()+1)); - } - break; - case ICmpInst::ICMP_SGE: - assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); - if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_SLE: - assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with 
ConstantInt not folded!"); - if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_UGE: - assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); - if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - case ICmpInst::ICMP_ULE: - assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); - if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - break; - } + if (Instruction *Res = foldICmpWithConstant(I)) + return Res; - // Turn a signed comparison into an unsigned one if both operands - // are known to have the same sign. - if (I.isSigned() && - ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) || - (Op0KnownOne.isNegative() && Op1KnownOne.isNegative()))) - return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1); - } + if (Instruction *Res = foldICmpUsingKnownBits(I)) + return Res; // Test if the ICmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing @@ -3614,122 +4289,39 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) return nullptr; - // See if we are doing a comparison between a constant and an instruction that - // can be folded into the comparison. - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - Value *A = nullptr, *B = nullptr; - // Since the RHS is a ConstantInt (CI), if the left hand side is an - // instruction, see if that instruction also has constants so that the - // instruction can be folded into the icmp - if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) - if (Instruction *Res = visitICmpInstWithInstAndIntCst(I, LHSI, CI)) - return Res; + // FIXME: We only do this after checking for min/max to prevent infinite + // looping caused by a reverse canonicalization of these patterns for min/max. + // FIXME: The organization of folds is a mess. These would naturally go into + // canonicalizeCmpWithConstant(), but we can't move all of the above folds + // down here after the min/max restriction. + ICmpInst::Predicate Pred = I.getPredicate(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + // For i32: x >u 2147483647 -> x <s 0 -> true if sign bit set + if (Pred == ICmpInst::ICMP_UGT && C->isMaxSignedValue()) { + Constant *Zero = Constant::getNullValue(Op0->getType()); + return new ICmpInst(ICmpInst::ICMP_SLT, Op0, Zero); + } - // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) - if (I.isEquality() && CI->isZero() && - match(Op0, m_UDiv(m_Value(A), m_Value(B)))) { - ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_EQ - ? 
ICmpInst::ICMP_UGT - : ICmpInst::ICMP_ULE; - return new ICmpInst(Pred, B, A); + // For i32: x <u 2147483648 -> x >s -1 -> true if sign bit clear + if (Pred == ICmpInst::ICMP_ULT && C->isMinSignedValue()) { + Constant *AllOnes = Constant::getAllOnesValue(Op0->getType()); + return new ICmpInst(ICmpInst::ICMP_SGT, Op0, AllOnes); } } - // Handle icmp with constant (but not simple integer constant) RHS - if (Constant *RHSC = dyn_cast<Constant>(Op1)) { - if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) - switch (LHSI->getOpcode()) { - case Instruction::GetElementPtr: - // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null - if (RHSC->isNullValue() && - cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices()) - return new ICmpInst(I.getPredicate(), LHSI->getOperand(0), - Constant::getNullValue(LHSI->getOperand(0)->getType())); - break; - case Instruction::PHI: - // Only fold icmp into the PHI if the phi and icmp are in the same - // block. If in the same block, we're encouraging jump threading. If - // not, we are just pessimizing the code by making an i1 phi. - if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; - break; - case Instruction::Select: { - // If either operand of the select is a constant, we can fold the - // comparison into the select arms, which will cause one to be - // constant folded and the select turned into a bitwise or. - Value *Op1 = nullptr, *Op2 = nullptr; - ConstantInt *CI = nullptr; - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { - Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); - CI = dyn_cast<ConstantInt>(Op1); - } - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) { - Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); - CI = dyn_cast<ConstantInt>(Op2); - } - - // We only want to perform this transformation if it will not lead to - // additional code. This is true if either both sides of the select - // fold to a constant (in which case the icmp is replaced with a select - // which will usually simplify) or this is the only user of the - // select (in which case we are trading a select+icmp for a simpler - // select+icmp) or all uses of the select can be replaced based on - // dominance information ("Global cases"). - bool Transform = false; - if (Op1 && Op2) - Transform = true; - else if (Op1 || Op2) { - // Local case - if (LHSI->hasOneUse()) - Transform = true; - // Global cases - else if (CI && !CI->isZero()) - // When Op1 is constant try replacing select with second operand. - // Otherwise Op2 is constant and try replacing select with first - // operand. - Transform = replacedSelectWithOperand(cast<SelectInst>(LHSI), &I, - Op1 ? 2 : 1); - } - if (Transform) { - if (!Op1) - Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1), - RHSC, I.getName()); - if (!Op2) - Op2 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(2), - RHSC, I.getName()); - return SelectInst::Create(LHSI->getOperand(0), Op1, Op2); - } - break; - } - case Instruction::IntToPtr: - // icmp pred inttoptr(X), null -> icmp pred X, 0 - if (RHSC->isNullValue() && - DL.getIntPtrType(RHSC->getType()) == LHSI->getOperand(0)->getType()) - return new ICmpInst(I.getPredicate(), LHSI->getOperand(0), - Constant::getNullValue(LHSI->getOperand(0)->getType())); - break; + if (Instruction *Res = foldICmpInstWithConstant(I)) + return Res; - case Instruction::Load: - // Try to optimize things like "A[i] > 4" to index computations. 
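// Illustrative sketch, not part of this patch: the idea behind
// foldCmpLoadFromIndexedGlobal ("A[i] > 4"). When the load comes from a
// constant global whose contents are known, the compare can be answered from
// the index alone. The table below is a made-up example.
#include <cassert>
#include <cstddef>

static const int A[4] = {1, 2, 5, 9}; // hypothetical constant global

int main() {
  for (size_t i = 0; i < 4; ++i) {
    bool Loaded = A[i] > 4;  // compare on the loaded value
    bool OnIndex = i >= 2;   // equivalent compare on the index for this table
    assert(Loaded == OnIndex);
  }
  return 0;
}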
- if (GetElementPtrInst *GEP = - dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) { - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !cast<LoadInst>(LHSI)->isVolatile()) - if (Instruction *Res = FoldCmpLoadFromIndexedGlobal(GEP, GV, I)) - return Res; - } - break; - } - } + if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) + return Res; // If we can optimize a 'icmp GEP, P' or 'icmp P, GEP', do so now. if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op0)) - if (Instruction *NI = FoldGEPICmp(GEP, Op1, I.getPredicate(), I)) + if (Instruction *NI = foldGEPICmp(GEP, Op1, I.getPredicate(), I)) return NI; if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op1)) - if (Instruction *NI = FoldGEPICmp(GEP, Op0, + if (Instruction *NI = foldGEPICmp(GEP, Op0, ICmpInst::getSwappedPredicate(I.getPredicate()), I)) return NI; @@ -3737,10 +4329,10 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Op0->getType()->isPointerTy() && I.isEquality()) { assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?"); if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op0, DL))) - if (Instruction *New = FoldAllocaCmp(I, Alloca, Op1)) + if (Instruction *New = foldAllocaCmp(I, Alloca, Op1)) return New; if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op1, DL))) - if (Instruction *New = FoldAllocaCmp(I, Alloca, Op0)) + if (Instruction *New = foldAllocaCmp(I, Alloca, Op0)) return New; } @@ -3780,318 +4372,24 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // For generality, we handle any zero-extension of any operand comparison // with a constant or another cast from the same type. if (isa<Constant>(Op1) || isa<CastInst>(Op1)) - if (Instruction *R = visitICmpInstWithCastAndCast(I)) + if (Instruction *R = foldICmpWithCastAndCast(I)) return R; } - // Special logic for binary operators. - BinaryOperator *BO0 = dyn_cast<BinaryOperator>(Op0); - BinaryOperator *BO1 = dyn_cast<BinaryOperator>(Op1); - if (BO0 || BO1) { - CmpInst::Predicate Pred = I.getPredicate(); - bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; - if (BO0 && isa<OverflowingBinaryOperator>(BO0)) - NoOp0WrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && BO0->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && BO0->hasNoSignedWrap()); - if (BO1 && isa<OverflowingBinaryOperator>(BO1)) - NoOp1WrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && BO1->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && BO1->hasNoSignedWrap()); - - // Analyze the case when either Op0 or Op1 is an add instruction. - // Op0 = A + B (or A and B are null); Op1 = C + D (or C and D are null). - Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Add) { - A = BO0->getOperand(0); - B = BO0->getOperand(1); - } - if (BO1 && BO1->getOpcode() == Instruction::Add) { - C = BO1->getOperand(0); - D = BO1->getOperand(1); - } - - // icmp (X+cst) < 0 --> X < -cst - if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) - if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) - if (!RHSC->isMinValue(/*isSigned=*/true)) - return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); - - // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. - if ((A == Op1 || B == Op1) && NoOp0WrapProblem) - return new ICmpInst(Pred, A == Op1 ? 
B : A, - Constant::getNullValue(Op1->getType())); - - // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. - if ((C == Op0 || D == Op0) && NoOp1WrapProblem) - return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), - C == Op0 ? D : C); - - // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. - if (A && C && (A == C || A == D || B == C || B == D) && - NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) { - // Determine Y and Z in the form icmp (X+Y), (X+Z). - Value *Y, *Z; - if (A == C) { - // C + B == C + D -> B == D - Y = B; - Z = D; - } else if (A == D) { - // D + B == C + D -> B == C - Y = B; - Z = C; - } else if (B == C) { - // A + C == C + D -> A == D - Y = A; - Z = D; - } else { - assert(B == D); - // A + D == C + D -> A == C - Y = A; - Z = C; - } - return new ICmpInst(Pred, Y, Z); - } - - // icmp slt (X + -1), Y -> icmp sle X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT && - match(B, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SLE, A, Op1); - - // icmp sge (X + -1), Y -> icmp sgt X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE && - match(B, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SGT, A, Op1); - - // icmp sle (X + 1), Y -> icmp slt X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && - match(B, m_One())) - return new ICmpInst(CmpInst::ICMP_SLT, A, Op1); - - // icmp sgt (X + 1), Y -> icmp sge X, Y - if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && - match(B, m_One())) - return new ICmpInst(CmpInst::ICMP_SGE, A, Op1); - - // icmp sgt X, (Y + -1) -> icmp sge X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT && - match(D, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SGE, Op0, C); - - // icmp sle X, (Y + -1) -> icmp slt X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE && - match(D, m_AllOnes())) - return new ICmpInst(CmpInst::ICMP_SLT, Op0, C); - - // icmp sge X, (Y + 1) -> icmp sgt X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && - match(D, m_One())) - return new ICmpInst(CmpInst::ICMP_SGT, Op0, C); - - // icmp slt X, (Y + 1) -> icmp sle X, Y - if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && - match(D, m_One())) - return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); - - // if C1 has greater magnitude than C2: - // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y - // s.t. C3 = C1 - C2 - // - // if C2 has greater magnitude than C1: - // icmp (X + C1), (Y + C2) -> icmp X, (Y + C3) - // s.t. C3 = C2 - C1 - if (A && C && NoOp0WrapProblem && NoOp1WrapProblem && - (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned()) - if (ConstantInt *C1 = dyn_cast<ConstantInt>(B)) - if (ConstantInt *C2 = dyn_cast<ConstantInt>(D)) { - const APInt &AP1 = C1->getValue(); - const APInt &AP2 = C2->getValue(); - if (AP1.isNegative() == AP2.isNegative()) { - APInt AP1Abs = C1->getValue().abs(); - APInt AP2Abs = C2->getValue().abs(); - if (AP1Abs.uge(AP2Abs)) { - ConstantInt *C3 = Builder->getInt(AP1 - AP2); - Value *NewAdd = Builder->CreateNSWAdd(A, C3); - return new ICmpInst(Pred, NewAdd, C); - } else { - ConstantInt *C3 = Builder->getInt(AP2 - AP1); - Value *NewAdd = Builder->CreateNSWAdd(C, C3); - return new ICmpInst(Pred, A, NewAdd); - } - } - } - - - // Analyze the case when either Op0 or Op1 is a sub instruction. - // Op0 = A - B (or A and B are null); Op1 = C - D (or C and D are null). 
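// Illustrative sketch, not part of this patch: why the add-based folds above
// (e.g. "icmp slt (X + -1), Y --> icmp sle X, Y") require the no-signed-wrap
// guarantee. Exhaustive check over i8; the single excluded case is exactly the
// one where X + -1 would wrap, which nsw rules out.
#include <cassert>
#include <cstdint>

int main() {
  for (int x = -128; x <= 127; ++x) {
    for (int y = -128; y <= 127; ++y) {
      if (x == -128)
        continue; // x + -1 wraps here; nsw excludes this case
      bool Orig = (int8_t)(x - 1) < (int8_t)y;
      bool Folded = (int8_t)x <= (int8_t)y;
      assert(Orig == Folded);
    }
  }
  return 0;
}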
- A = nullptr; - B = nullptr; - C = nullptr; - D = nullptr; - if (BO0 && BO0->getOpcode() == Instruction::Sub) { - A = BO0->getOperand(0); - B = BO0->getOperand(1); - } - if (BO1 && BO1->getOpcode() == Instruction::Sub) { - C = BO1->getOperand(0); - D = BO1->getOperand(1); - } - - // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow. - if (A == Op1 && NoOp0WrapProblem) - return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B); - - // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow. - if (C == Op0 && NoOp1WrapProblem) - return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType())); - - // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow. - if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) - return new ICmpInst(Pred, A, C); - - // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow. - if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) - return new ICmpInst(Pred, D, B); - - // icmp (0-X) < cst --> x > -cst - if (NoOp0WrapProblem && ICmpInst::isSigned(Pred)) { - Value *X; - if (match(BO0, m_Neg(m_Value(X)))) - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) - if (!RHSC->isMinValue(/*isSigned=*/true)) - return new ICmpInst(I.getSwappedPredicate(), X, - ConstantExpr::getNeg(RHSC)); - } - - BinaryOperator *SRem = nullptr; - // icmp (srem X, Y), Y - if (BO0 && BO0->getOpcode() == Instruction::SRem && - Op1 == BO0->getOperand(1)) - SRem = BO0; - // icmp Y, (srem X, Y) - else if (BO1 && BO1->getOpcode() == Instruction::SRem && - Op0 == BO1->getOperand(1)) - SRem = BO1; - if (SRem) { - // We don't check hasOneUse to avoid increasing register pressure because - // the value we use is the same value this instruction was already using. - switch (SRem == BO0 ? ICmpInst::getSwappedPredicate(Pred) : Pred) { - default: break; - case ICmpInst::ICMP_EQ: - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - case ICmpInst::ICMP_NE: - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - return new ICmpInst(ICmpInst::ICMP_SGT, SRem->getOperand(1), - Constant::getAllOnesValue(SRem->getType())); - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - return new ICmpInst(ICmpInst::ICMP_SLT, SRem->getOperand(1), - Constant::getNullValue(SRem->getType())); - } - } - - if (BO0 && BO1 && BO0->getOpcode() == BO1->getOpcode() && - BO0->hasOneUse() && BO1->hasOneUse() && - BO0->getOperand(1) == BO1->getOperand(1)) { - switch (BO0->getOpcode()) { - default: break; - case Instruction::Add: - case Instruction::Sub: - case Instruction::Xor: - if (I.isEquality()) // a+x icmp eq/ne b+x --> a icmp b - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - // icmp u/s (a ^ signbit), (b ^ signbit) --> icmp s/u a, b - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { - if (CI->getValue().isSignBit()) { - ICmpInst::Predicate Pred = I.isSigned() - ? I.getUnsignedPredicate() - : I.getSignedPredicate(); - return new ICmpInst(Pred, BO0->getOperand(0), - BO1->getOperand(0)); - } - - if (BO0->getOpcode() == Instruction::Xor && CI->isMaxValue(true)) { - ICmpInst::Predicate Pred = I.isSigned() - ? 
I.getUnsignedPredicate() - : I.getSignedPredicate(); - Pred = I.getSwappedPredicate(Pred); - return new ICmpInst(Pred, BO0->getOperand(0), - BO1->getOperand(0)); - } - } - break; - case Instruction::Mul: - if (!I.isEquality()) - break; - - if (ConstantInt *CI = dyn_cast<ConstantInt>(BO0->getOperand(1))) { - // a * Cst icmp eq/ne b * Cst --> a & Mask icmp b & Mask - // Mask = -1 >> count-trailing-zeros(Cst). - if (!CI->isZero() && !CI->isOne()) { - const APInt &AP = CI->getValue(); - ConstantInt *Mask = ConstantInt::get(I.getContext(), - APInt::getLowBitsSet(AP.getBitWidth(), - AP.getBitWidth() - - AP.countTrailingZeros())); - Value *And1 = Builder->CreateAnd(BO0->getOperand(0), Mask); - Value *And2 = Builder->CreateAnd(BO1->getOperand(0), Mask); - return new ICmpInst(I.getPredicate(), And1, And2); - } - } - break; - case Instruction::UDiv: - case Instruction::LShr: - if (I.isSigned()) - break; - // fall-through - case Instruction::SDiv: - case Instruction::AShr: - if (!BO0->isExact() || !BO1->isExact()) - break; - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - case Instruction::Shl: { - bool NUW = BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap(); - bool NSW = BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap(); - if (!NUW && !NSW) - break; - if (!NSW && I.isSigned()) - break; - return new ICmpInst(I.getPredicate(), BO0->getOperand(0), - BO1->getOperand(0)); - } - } - } - - if (BO0) { - // Transform A & (L - 1) `ult` L --> L != 0 - auto LSubOne = m_Add(m_Specific(Op1), m_AllOnes()); - auto BitwiseAnd = - m_CombineOr(m_And(m_Value(), LSubOne), m_And(LSubOne, m_Value())); + if (Instruction *Res = foldICmpBinOp(I)) + return Res; - if (match(BO0, BitwiseAnd) && I.getPredicate() == ICmpInst::ICMP_ULT) { - auto *Zero = Constant::getNullValue(BO0->getType()); - return new ICmpInst(ICmpInst::ICMP_NE, Op1, Zero); - } - } - } + if (Instruction *Res = foldICmpWithMinMax(I)) + return Res; - { Value *A, *B; + { + Value *A, *B; // Transform (A & ~B) == 0 --> (A & B) != 0 // and (A & ~B) != 0 --> (A & B) == 0 // if A is a power of 2. if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Zero()) && - isKnownToBeAPowerOfTwo(A, DL, false, 0, AC, &I, DT) && I.isEquality()) + isKnownToBeAPowerOfTwo(A, DL, false, 0, &AC, &I, &DT) && I.isEquality()) return new ICmpInst(I.getInversePredicate(), Builder->CreateAnd(A, B), Op1); @@ -4120,149 +4418,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // (zext a) * (zext b) --> llvm.umul.with.overflow. if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { - if (Instruction *R = ProcessUMulZExtIdiom(I, Op0, Op1, *this)) + if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this)) return R; } if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { - if (Instruction *R = ProcessUMulZExtIdiom(I, Op1, Op0, *this)) + if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this)) return R; } } - if (I.isEquality()) { - Value *A, *B, *C, *D; - - if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { - if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 - Value *OtherVal = A == Op1 ? 
B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); - } - - if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { - // A^c1 == C^c2 --> A == C^(c1^c2) - ConstantInt *C1, *C2; - if (match(B, m_ConstantInt(C1)) && - match(D, m_ConstantInt(C2)) && Op1->hasOneUse()) { - Constant *NC = Builder->getInt(C1->getValue() ^ C2->getValue()); - Value *Xor = Builder->CreateXor(C, NC); - return new ICmpInst(I.getPredicate(), A, Xor); - } - - // A^B == A^D -> B == D - if (A == C) return new ICmpInst(I.getPredicate(), B, D); - if (A == D) return new ICmpInst(I.getPredicate(), B, C); - if (B == C) return new ICmpInst(I.getPredicate(), A, D); - if (B == D) return new ICmpInst(I.getPredicate(), A, C); - } - } - - if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && - (A == Op0 || B == Op0)) { - // A == (A^B) -> B == 0 - Value *OtherVal = A == Op0 ? B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); - } - - // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 - if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B)))) && - match(Op1, m_OneUse(m_And(m_Value(C), m_Value(D))))) { - Value *X = nullptr, *Y = nullptr, *Z = nullptr; - - if (A == C) { - X = B; Y = D; Z = A; - } else if (A == D) { - X = B; Y = C; Z = A; - } else if (B == C) { - X = A; Y = D; Z = B; - } else if (B == D) { - X = A; Y = C; Z = B; - } - - if (X) { // Build (X^Y) & Z - Op1 = Builder->CreateXor(X, Y); - Op1 = Builder->CreateAnd(Op1, Z); - I.setOperand(0, Op1); - I.setOperand(1, Constant::getNullValue(Op1->getType())); - return &I; - } - } - - // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) - // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) - ConstantInt *Cst1; - if ((Op0->hasOneUse() && - match(Op0, m_ZExt(m_Value(A))) && - match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) || - (Op1->hasOneUse() && - match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) && - match(Op1, m_ZExt(m_Value(A))))) { - APInt Pow2 = Cst1->getValue() + 1; - if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && - Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) - return new ICmpInst(I.getPredicate(), A, - Builder->CreateTrunc(B, A->getType())); - } - - // (A >> C) == (B >> C) --> (A^B) u< (1 << C) - // For lshr and ashr pairs. - if ((match(Op0, m_OneUse(m_LShr(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_LShr(m_Value(B), m_Specific(Cst1))))) || - (match(Op0, m_OneUse(m_AShr(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_AShr(m_Value(B), m_Specific(Cst1)))))) { - unsigned TypeBits = Cst1->getBitWidth(); - unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); - if (ShAmt < TypeBits && ShAmt != 0) { - ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_NE - ? 
ICmpInst::ICMP_UGE - : ICmpInst::ICMP_ULT; - Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); - APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); - return new ICmpInst(Pred, Xor, Builder->getInt(CmpVal)); - } - } - - // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0 - if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) && - match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) { - unsigned TypeBits = Cst1->getBitWidth(); - unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); - if (ShAmt < TypeBits && ShAmt != 0) { - Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); - APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt); - Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal), - I.getName() + ".mask"); - return new ICmpInst(I.getPredicate(), And, - Constant::getNullValue(Cst1->getType())); - } - } - - // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to - // "icmp (and X, mask), cst" - uint64_t ShAmt = 0; - if (Op0->hasOneUse() && - match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A), - m_ConstantInt(ShAmt))))) && - match(Op1, m_ConstantInt(Cst1)) && - // Only do this when A has multiple uses. This is most important to do - // when it exposes other optimizations. - !A->hasOneUse()) { - unsigned ASize =cast<IntegerType>(A->getType())->getPrimitiveSizeInBits(); - - if (ShAmt < ASize) { - APInt MaskV = - APInt::getLowBitsSet(ASize, Op0->getType()->getPrimitiveSizeInBits()); - MaskV <<= ShAmt; - - APInt CmpV = Cst1->getValue().zext(ASize); - CmpV <<= ShAmt; - - Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV)); - return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV)); - } - } - } + if (Instruction *Res = foldICmpEquality(I)) + return Res; // The 'cmpxchg' instruction returns an aggregate containing the old value and // an i1 which indicates whether or not we successfully did the swap. @@ -4284,18 +4450,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Value *X; ConstantInt *Cst; // icmp X+Cst, X if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X) - return FoldICmpAddOpCst(I, X, Cst, I.getPredicate()); + return foldICmpAddOpConst(I, X, Cst, I.getPredicate()); // icmp X, X+Cst if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X) - return FoldICmpAddOpCst(I, X, Cst, I.getSwappedPredicate()); + return foldICmpAddOpConst(I, X, Cst, I.getSwappedPredicate()); } return Changed ? &I : nullptr; } /// Fold fcmp ([us]itofp x, cst) if possible. -Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, - Instruction *LHSI, +Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC) { if (!isa<ConstantFP>(RHSC)) return nullptr; const APFloat &RHS = cast<ConstantFP>(RHSC)->getValueAPF(); @@ -4339,21 +4504,21 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I, // This would allow us to handle (fptosi (x >>s 62) to float) if x is i64 f.e. unsigned InputSize = IntTy->getScalarSizeInBits(); - // Following test does NOT adjust InputSize downwards for signed inputs, - // because the most negative value still requires all the mantissa bits + // Following test does NOT adjust InputSize downwards for signed inputs, + // because the most negative value still requires all the mantissa bits // to distinguish it from one less than that value. if ((int)InputSize > MantissaWidth) { // Conversion would lose accuracy. Check if loss can impact comparison. 
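// Illustrative sketch, not part of this patch: why foldFCmpIntToFPConst must
// compare the integer width against the mantissa width. A 64-bit integer does
// not always survive conversion to double (53-bit mantissa), so a compare done
// in the FP domain can lose distinctions that exist in the integer domain.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t a = (1LL << 53) + 1; // needs more than 53 mantissa bits
  const double d = (double)a;        // rounds to a neighboring value
  assert((int64_t)d != a);           // the round trip does not recover a
  return 0;
}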
int Exp = ilogb(RHS); if (Exp == APFloat::IEK_Inf) { int MaxExponent = ilogb(APFloat::getLargest(RHS.getSemantics())); - if (MaxExponent < (int)InputSize - !LHSUnsigned) + if (MaxExponent < (int)InputSize - !LHSUnsigned) // Conversion could create infinity. return nullptr; } else { - // Note that if RHS is zero or NaN, then Exp is negative + // Note that if RHS is zero or NaN, then Exp is negative // and first condition is trivially false. - if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned) + if (MantissaWidth <= Exp && Exp <= (int)InputSize - !LHSUnsigned) // Conversion could affect comparison. return nullptr; } @@ -4547,7 +4712,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, - I.getFastMathFlags(), DL, TLI, DT, AC, &I)) + I.getFastMathFlags(), DL, &TLI, &DT, &AC, &I)) return replaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' @@ -4601,17 +4766,17 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { const fltSemantics *Sem; // FIXME: This shouldn't be here. if (LHSExt->getSrcTy()->isHalfTy()) - Sem = &APFloat::IEEEhalf; + Sem = &APFloat::IEEEhalf(); else if (LHSExt->getSrcTy()->isFloatTy()) - Sem = &APFloat::IEEEsingle; + Sem = &APFloat::IEEEsingle(); else if (LHSExt->getSrcTy()->isDoubleTy()) - Sem = &APFloat::IEEEdouble; + Sem = &APFloat::IEEEdouble(); else if (LHSExt->getSrcTy()->isFP128Ty()) - Sem = &APFloat::IEEEquad; + Sem = &APFloat::IEEEquad(); else if (LHSExt->getSrcTy()->isX86_FP80Ty()) - Sem = &APFloat::x87DoubleExtended; + Sem = &APFloat::x87DoubleExtended(); else if (LHSExt->getSrcTy()->isPPC_FP128Ty()) - Sem = &APFloat::PPCDoubleDouble; + Sem = &APFloat::PPCDoubleDouble(); else break; @@ -4641,7 +4806,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; case Instruction::SIToFP: case Instruction::UIToFP: - if (Instruction *NV = FoldFCmp_IntToFP_Cst(I, LHSI, RHSC)) + if (Instruction *NV = foldFCmpIntToFPConst(I, LHSI, RHSC)) return NV; break; case Instruction::FSub: { @@ -4658,7 +4823,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) if (GV->isConstant() && GV->hasDefinitiveInitializer() && !cast<LoadInst>(LHSI)->isVolatile()) - if (Instruction *Res = FoldCmpLoadFromIndexedGlobal(GEP, GV, I)) + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) return Res; } break; @@ -4667,7 +4832,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; CallInst *CI = cast<CallInst>(LHSI); - Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); + Intrinsic::ID IID = getIntrinsicForCallSite(CI, &TLI); if (IID != Intrinsic::fabs) break; diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/contrib/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index aa421ff..2847ce8 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -84,6 +84,24 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { if (isa<ConstantInt>(V)) return true; + // A vector of constant integers can be inverted easily. 
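// Illustrative sketch, not part of this patch: what "free to invert" means.
// ~V costs no extra instruction when V is an integer constant (or, with this
// change, a vector of integer constants), because the bitwise-not folds, and
// when V is a compare, because the inversion is just the opposite predicate.
#include <cassert>
#include <cstdint>

int main() {
  constexpr uint8_t C = 0x0F;
  static_assert((uint8_t)~C == 0xF0, "inverting a constant folds away");

  for (int a = -128; a <= 127; ++a)
    for (int b = -128; b <= 127; ++b)
      assert(!(a < b) == (a >= b)); // ~(icmp slt a, b) == icmp sge a, b
  return 0;
}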
+ Constant *CV; + if (V->getType()->isVectorTy() && match(V, PatternMatch::m_Constant(CV))) { + unsigned NumElts = V->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = CV->getAggregateElement(i); + if (!Elt) + return false; + + if (isa<UndefValue>(Elt)) + continue; + + if (!isa<ConstantInt>(Elt)) + return false; + } + return true; + } + // Compares can be inverted if all of their uses are being modified to use the // ~V. if (isa<CmpInst>(V)) @@ -135,33 +153,10 @@ IntrinsicIDToOverflowCheckFlavor(unsigned ID) { } } -/// \brief An IRBuilder inserter that adds new instructions to the instcombine -/// worklist. -class LLVM_LIBRARY_VISIBILITY InstCombineIRInserter - : public IRBuilderDefaultInserter { - InstCombineWorklist &Worklist; - AssumptionCache *AC; - -public: - InstCombineIRInserter(InstCombineWorklist &WL, AssumptionCache *AC) - : Worklist(WL), AC(AC) {} - - void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, - BasicBlock::iterator InsertPt) const { - IRBuilderDefaultInserter::InsertHelper(I, Name, BB, InsertPt); - Worklist.Add(I); - - using namespace llvm::PatternMatch; - if (match(I, m_Intrinsic<Intrinsic::assume>())) - AC->registerAssumption(cast<CallInst>(I)); - } -}; - /// \brief The core instruction combiner logic. /// /// This class provides both the logic to recursively visit instructions and -/// combine them, as well as the pass infrastructure for running this as part -/// of the LLVM pass pipeline. +/// combine them. class LLVM_LIBRARY_VISIBILITY InstCombiner : public InstVisitor<InstCombiner, Instruction *> { // FIXME: These members shouldn't be public. @@ -171,7 +166,7 @@ public: /// \brief An IRBuilder that automatically inserts new instructions into the /// worklist. - typedef IRBuilder<TargetFolder, InstCombineIRInserter> BuilderTy; + typedef IRBuilder<TargetFolder, IRBuilderCallbackInserter> BuilderTy; BuilderTy *Builder; private: @@ -183,10 +178,9 @@ private: AliasAnalysis *AA; // Required analyses. - // FIXME: These can never be null and should be references. - AssumptionCache *AC; - TargetLibraryInfo *TLI; - DominatorTree *DT; + AssumptionCache &AC; + TargetLibraryInfo &TLI; + DominatorTree &DT; const DataLayout &DL; // Optional analyses. When non-null, these can both be used to do better @@ -198,8 +192,8 @@ private: public: InstCombiner(InstCombineWorklist &Worklist, BuilderTy *Builder, bool MinimizeSize, bool ExpensiveCombines, AliasAnalysis *AA, - AssumptionCache *AC, TargetLibraryInfo *TLI, - DominatorTree *DT, const DataLayout &DL, LoopInfo *LI) + AssumptionCache &AC, TargetLibraryInfo &TLI, + DominatorTree &DT, const DataLayout &DL, LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL), LI(LI), MadeIRChange(false) {} @@ -209,15 +203,15 @@ public: /// \returns true if the IR is changed. bool run(); - AssumptionCache *getAssumptionCache() const { return AC; } + AssumptionCache &getAssumptionCache() const { return AC; } const DataLayout &getDataLayout() const { return DL; } - DominatorTree *getDominatorTree() const { return DT; } + DominatorTree &getDominatorTree() const { return DT; } LoopInfo *getLoopInfo() const { return LI; } - TargetLibraryInfo *getTargetLibraryInfo() const { return TLI; } + TargetLibraryInfo &getTargetLibraryInfo() const { return TLI; } // Visitation implementation - Implement instruction combining for different // instruction types. 
The semantics are as follows: @@ -262,29 +256,8 @@ public: Instruction *visitAShr(BinaryOperator &I); Instruction *visitLShr(BinaryOperator &I); Instruction *commonShiftTransforms(BinaryOperator &I); - Instruction *FoldFCmp_IntToFP_Cst(FCmpInst &I, Instruction *LHSI, - Constant *RHSC); - Instruction *FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, - GlobalVariable *GV, CmpInst &ICI, - ConstantInt *AndCst = nullptr); Instruction *visitFCmpInst(FCmpInst &I); Instruction *visitICmpInst(ICmpInst &I); - Instruction *visitICmpInstWithCastAndCast(ICmpInst &ICI); - Instruction *visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Instruction *LHS, - ConstantInt *RHS); - Instruction *FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS); - Instruction *FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *DivI, - ConstantInt *DivRHS); - Instruction *FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, ConstantInt *CI2); - Instruction *FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A, - ConstantInt *CI1, ConstantInt *CI2); - Instruction *FoldICmpAddOpCst(Instruction &ICI, Value *X, ConstantInt *CI, - ICmpInst::Predicate Pred); - Instruction *FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS, - ICmpInst::Predicate Cond, Instruction &I); - Instruction *FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca, Value *Other); Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I); Instruction *commonCastTransforms(CastInst &CI); @@ -302,14 +275,8 @@ public: Instruction *visitIntToPtr(IntToPtrInst &CI); Instruction *visitBitCast(BitCastInst &CI); Instruction *visitAddrSpaceCast(AddrSpaceCastInst &CI); - Instruction *FoldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); - Instruction *FoldSelectIntoOp(SelectInst &SI, Value *, Value *); - Instruction *FoldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, - Value *A, Value *B, Instruction &Outer, - SelectPatternFlavor SPF2, Value *C); Instruction *FoldItoFPtoI(Instruction &FI); Instruction *visitSelectInst(SelectInst &SI); - Instruction *visitSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); Instruction *visitCallInst(CallInst &CI); Instruction *visitInvokeInst(InvokeInst &II); @@ -333,16 +300,16 @@ public: Instruction *visitVAStartInst(VAStartInst &I); Instruction *visitVACopyInst(VACopyInst &I); - // visitInstruction - Specify what to return for unhandled instructions... + /// Specify what to return for unhandled instructions. Instruction *visitInstruction(Instruction &I) { return nullptr; } - // True when DB dominates all uses of DI execpt UI. - // UI must be in the same block as DI. - // The routine checks that the DI parent and DB are different. + /// True when DB dominates all uses of DI except UI. + /// UI must be in the same block as DI. + /// The routine checks that the DI parent and DB are different. bool dominatesAllUses(const Instruction *DI, const Instruction *UI, const BasicBlock *DB) const; - // Replace select with select operand SIOpd in SI-ICmp sequence when possible + /// Try to replace select with select operand SIOpd in SI-ICmp sequence. bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, const unsigned SIOpd); @@ -353,16 +320,17 @@ private: Value *dyn_castFNegVal(Value *V, bool NoSignedZero = false) const; Type *FindElementAtOffset(PointerType *PtrTy, int64_t Offset, SmallVectorImpl<Value *> &NewIndices); - Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI); - /// \brief Classify whether a cast is worth optimizing. 
+ /// Classify whether a cast is worth optimizing. + /// + /// This is a helper to decide whether the simplification of + /// logic(cast(A), cast(B)) to cast(logic(A, B)) should be performed. + /// + /// \param CI The cast we are interested in. /// - /// Returns true if the cast from "V to Ty" actually results in any code - /// being generated and is interesting to optimize out. If the cast can be - /// eliminated by some other simple transformation, we prefer to do the - /// simplification first. - bool ShouldOptimizeCast(Instruction::CastOps opcode, const Value *V, - Type *Ty); + /// \return true if this cast actually results in any code being generated and + /// if it cannot already be eliminated by some other transformation. + bool shouldOptimizeCast(CastInst *CI); /// \brief Try to optimize a sequence of instructions checking if an operation /// on LHS and RHS overflows. @@ -385,8 +353,22 @@ private: bool transformConstExprCastCall(CallSite CS); Instruction *transformCallThroughTrampoline(CallSite CS, IntrinsicInst *Tramp); - Instruction *transformZExtICmp(ICmpInst *ICI, Instruction &CI, - bool DoXform = true); + + /// Transform (zext icmp) to bitwise / integer operations in order to + /// eliminate it. + /// + /// \param ICI The icmp of the (zext icmp) pair we are interested in. + /// \parem CI The zext of the (zext icmp) pair we are interested in. + /// \param DoTransform Pass false to just test whether the given (zext icmp) + /// would be transformed. Pass true to actually perform the transformation. + /// + /// \return null if the transformation cannot be performed. If the + /// transformation can be performed the new instruction that replaces the + /// (zext icmp) pair will be returned (if \p DoTransform is false the + /// unmodified \p ICI will be returned in this case). + Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, + bool DoTransform = true); + Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS, Instruction &CxtI); bool WillNotOverflowSignedSub(Value *LHS, Value *RHS, Instruction &CxtI); @@ -396,6 +378,21 @@ private: Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); + Instruction *shrinkBitwiseLogic(TruncInst &Trunc); + Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); + + /// Determine if a pair of casts can be replaced by a single cast. + /// + /// \param CI1 The first of a pair of casts. + /// \param CI2 The second of a pair of casts. + /// + /// \return 0 if the cast pair cannot be eliminated, otherwise returns an + /// Instruction::CastOps value for a cast that can replace the pair, casting + /// CI1->getSrcTy() to CI2->getDstTy(). 
+ /// + /// \see CastInst::isEliminableCastPair + Instruction::CastOps isEliminableCastPair(const CastInst *CI1, + const CastInst *CI2); public: /// \brief Inserts an instruction \p New before instruction \p Old @@ -476,30 +473,30 @@ public: void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, unsigned Depth, Instruction *CxtI) const { - return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth, AC, CxtI, - DT); + return llvm::computeKnownBits(V, KnownZero, KnownOne, DL, Depth, &AC, CxtI, + &DT); } bool MaskedValueIsZero(Value *V, const APInt &Mask, unsigned Depth = 0, Instruction *CxtI = nullptr) const { - return llvm::MaskedValueIsZero(V, Mask, DL, Depth, AC, CxtI, DT); + return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT); } unsigned ComputeNumSignBits(Value *Op, unsigned Depth = 0, Instruction *CxtI = nullptr) const { - return llvm::ComputeNumSignBits(Op, DL, Depth, AC, CxtI, DT); + return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT); } void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, unsigned Depth = 0, Instruction *CxtI = nullptr) const { - return llvm::ComputeSignBit(V, KnownZero, KnownOne, DL, Depth, AC, CxtI, - DT); + return llvm::ComputeSignBit(V, KnownZero, KnownOne, DL, Depth, &AC, CxtI, + &DT); } OverflowResult computeOverflowForUnsignedMul(Value *LHS, Value *RHS, const Instruction *CxtI) { - return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, AC, CxtI, DT); + return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT); } OverflowResult computeOverflowForUnsignedAdd(Value *LHS, Value *RHS, const Instruction *CxtI) { - return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, AC, CxtI, DT); + return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); } private: @@ -539,13 +536,21 @@ private: Value *SimplifyVectorOp(BinaryOperator &Inst); Value *SimplifyBSwap(BinaryOperator &Inst); - // FoldOpIntoPhi - Given a binary operator, cast instruction, or select - // which has a PHI node as operand #0, see if we can fold the instruction - // into the PHI (which is only possible if all operands to the PHI are - // constants). - // + + /// Given a binary operator, cast instruction, or select which has a PHI node + /// as operand #0, see if we can fold the instruction into the PHI (which is + /// only possible if all operands to the PHI are constants). Instruction *FoldOpIntoPhi(Instruction &I); + /// Given an instruction with a select as one operand and a constant as the + /// other operand, try to fold the binary operator into the select arguments. + /// This also works for Cast instructions, which obviously do not have a + /// second operand. + Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI); + + /// This is a convenience wrapper function for the above two functions. + Instruction *foldOpWithConstantIntoOperand(Instruction &I); + /// \brief Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. Instruction *FoldPHIArgOpIntoPHI(PHINode &PN); @@ -554,13 +559,82 @@ private: Instruction *FoldPHIArgLoadIntoPHI(PHINode &PN); Instruction *FoldPHIArgZextsIntoPHI(PHINode &PN); + /// Helper function for FoldPHIArgXIntoPHI() to get debug location for the + /// folded operation. 
+ DebugLoc PHIArgMergedDebugLoc(PHINode &PN); + + Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, + ICmpInst::Predicate Cond, Instruction &I); + Instruction *foldAllocaCmp(ICmpInst &ICI, const AllocaInst *Alloca, + const Value *Other); + Instruction *foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, + GlobalVariable *GV, CmpInst &ICI, + ConstantInt *AndCst = nullptr); + Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, + Constant *RHSC); + Instruction *foldICmpAddOpConst(Instruction &ICI, Value *X, ConstantInt *CI, + ICmpInst::Predicate Pred); + Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); + + Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); + Instruction *foldICmpWithConstant(ICmpInst &Cmp); + Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); + Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); + Instruction *foldICmpBinOp(ICmpInst &Cmp); + Instruction *foldICmpEquality(ICmpInst &Cmp); + + Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc, + const APInt *C); + Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, + const APInt *C); + Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, + const APInt *C); + Instruction *foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, + const APInt *C); + Instruction *foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, + const APInt *C); + Instruction *foldICmpShlConstant(ICmpInst &Cmp, BinaryOperator *Shl, + const APInt *C); + Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, + const APInt *C); + Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, + const APInt *C); + Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, + const APInt *C); + Instruction *foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, + const APInt *C); + Instruction *foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, + const APInt *C); + Instruction *foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, + const APInt *C1); + Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, + const APInt *C1, const APInt *C2); + Instruction *foldICmpShrConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, + const APInt &C2); + Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, + const APInt &C2); + + Instruction *foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, + BinaryOperator *BO, + const APInt *C); + Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI, const APInt *C); + + // Helpers of visitSelectInst(). 
+ Instruction *foldSelectExtConst(SelectInst &Sel); + Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); + Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *); + Instruction *foldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, + Value *A, Value *B, Instruction &Outer, + SelectPatternFlavor SPF2, Value *C); + Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI); + Instruction *OptAndOp(Instruction *Op, ConstantInt *OpRHS, ConstantInt *AndRHS, BinaryOperator &TheAnd); Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantInt *Mask, bool isSub, Instruction &I); - Value *InsertRangeTest(Value *V, Constant *Lo, Constant *Hi, bool isSigned, - bool Inside); + Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, + bool isSigned, bool Inside); Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); Instruction *MatchBSwap(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index d88456e..49e516e 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/Loads.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/IntrinsicInst.h" @@ -59,14 +60,14 @@ isOnlyCopiedFromConstantGlobal(Value *V, MemTransferInst *&TheCopy, // eliminate the markers. SmallVector<std::pair<Value *, bool>, 35> ValuesToInspect; - ValuesToInspect.push_back(std::make_pair(V, false)); + ValuesToInspect.emplace_back(V, false); while (!ValuesToInspect.empty()) { auto ValuePair = ValuesToInspect.pop_back_val(); const bool IsOffset = ValuePair.second; for (auto &U : ValuePair.first->uses()) { - Instruction *I = cast<Instruction>(U.getUser()); + auto *I = cast<Instruction>(U.getUser()); - if (LoadInst *LI = dyn_cast<LoadInst>(I)) { + if (auto *LI = dyn_cast<LoadInst>(I)) { // Ignore non-volatile loads, they are always ok. if (!LI->isSimple()) return false; continue; @@ -74,14 +75,13 @@ isOnlyCopiedFromConstantGlobal(Value *V, MemTransferInst *&TheCopy, if (isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I)) { // If uses of the bitcast are ok, we are ok. - ValuesToInspect.push_back(std::make_pair(I, IsOffset)); + ValuesToInspect.emplace_back(I, IsOffset); continue; } - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) { + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { // If the GEP has all zero indices, it doesn't offset the pointer. If it // doesn't, it does. 
- ValuesToInspect.push_back( - std::make_pair(I, IsOffset || !GEP->hasAllZeroIndices())); + ValuesToInspect.emplace_back(I, IsOffset || !GEP->hasAllZeroIndices()); continue; } @@ -286,7 +286,7 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { SmallVector<Instruction *, 4> ToDelete; if (MemTransferInst *Copy = isOnlyCopiedFromConstantGlobal(&AI, ToDelete)) { unsigned SourceAlign = getOrEnforceKnownAlignment( - Copy->getSource(), AI.getAlignment(), DL, &AI, AC, DT); + Copy->getSource(), AI.getAlignment(), DL, &AI, &AC, &DT); if (AI.getAlignment() <= SourceAlign) { DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); @@ -308,6 +308,11 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { return visitAllocSite(AI); } +// Are we allowed to form a atomic load or store of this type? +static bool isSupportedAtomicType(Type *Ty) { + return Ty->isIntegerTy() || Ty->isPointerTy() || Ty->isFloatingPointTy(); +} + /// \brief Helper to combine a load to a new type. /// /// This just does the work of combining a load to a new type. It handles @@ -319,6 +324,9 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { /// point the \c InstCombiner currently is using. static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewTy, const Twine &Suffix = "") { + assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) && + "can't fold an atomic load to requested type"); + Value *Ptr = LI.getPointerOperand(); unsigned AS = LI.getPointerAddressSpace(); SmallVector<std::pair<unsigned, MDNode *>, 8> MD; @@ -380,8 +388,16 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT break; case LLVMContext::MD_range: // FIXME: It would be nice to propagate this in some way, but the type - // conversions make it hard. If the new type is a pointer, we could - // translate it to !nonnull metadata. + // conversions make it hard. + + // If it's a pointer now and the range does not contain 0, make it !nonnull. + if (NewTy->isPointerTy()) { + unsigned BitWidth = IC.getDataLayout().getTypeSizeInBits(NewTy); + if (!getConstantRangeFromMetadata(*N).contains(APInt(BitWidth, 0))) { + MDNode *NN = MDNode::get(LI.getContext(), None); + NewLoad->setMetadata(LLVMContext::MD_nonnull, NN); + } + } break; } } @@ -392,6 +408,9 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT /// /// Returns the newly created store instruction. static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value *V) { + assert((!SI.isAtomic() || isSupportedAtomicType(V->getType())) && + "can't fold an atomic store of requested type"); + Value *Ptr = SI.getPointerOperand(); unsigned AS = SI.getPointerAddressSpace(); SmallVector<std::pair<unsigned, MDNode *>, 8> MD; @@ -466,6 +485,10 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { if (LI.use_empty()) return nullptr; + // swifterror values can't be bitcasted. + if (LI.getPointerOperand()->isSwiftError()) + return nullptr; + Type *Ty = LI.getType(); const DataLayout &DL = IC.getDataLayout(); @@ -475,10 +498,12 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // size is a legal integer type. 
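// Illustrative sketch, not part of this patch: the rationale for canonicalizing
// a load that is only ever stored back into an integer load of the same store
// size. Copying the raw bits through a same-width integer reproduces the value
// exactly (assuming the usual 32-bit float, as on LLVM hosts).
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  float Src = 3.25f, Dst = 0.0f;
  uint32_t Bits;
  static_assert(sizeof(Bits) == sizeof(Src), "same store size assumed");
  std::memcpy(&Bits, &Src, sizeof(Bits)); // "load" as i32
  std::memcpy(&Dst, &Bits, sizeof(Bits)); // "store" as i32
  assert(Dst == Src);
  return 0;
}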
if (!Ty->isIntegerTy() && Ty->isSized() && DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && - DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty)) { - if (std::all_of(LI.user_begin(), LI.user_end(), [&LI](User *U) { + DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty) && + !DL.isNonIntegralPointerType(Ty)) { + if (all_of(LI.users(), [&LI](User *U) { auto *SI = dyn_cast<StoreInst>(U); - return SI && SI->getPointerOperand() != &LI; + return SI && SI->getPointerOperand() != &LI && + !SI->getPointerOperand()->isSwiftError(); })) { LoadInst *NewLoad = combineLoadToNewType( IC, LI, @@ -501,14 +526,14 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // as long as those are noops (i.e., the source or dest type have the same // bitwidth as the target's pointers). if (LI.hasOneUse()) - if (auto* CI = dyn_cast<CastInst>(LI.user_back())) { - if (CI->isNoopCast(DL)) { - LoadInst *NewLoad = combineLoadToNewType(IC, LI, CI->getDestTy()); - CI->replaceAllUsesWith(NewLoad); - IC.eraseInstFromFunction(*CI); - return &LI; - } - } + if (auto* CI = dyn_cast<CastInst>(LI.user_back())) + if (CI->isNoopCast(DL)) + if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) { + LoadInst *NewLoad = combineLoadToNewType(IC, LI, CI->getDestTy()); + CI->replaceAllUsesWith(NewLoad); + IC.eraseInstFromFunction(*CI); + return &LI; + } // FIXME: We should also canonicalize loads of vectors when their elements are // cast to other types. @@ -802,7 +827,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // Attempt to improve the alignment. unsigned KnownAlign = getOrEnforceKnownAlignment( - Op, DL.getPrefTypeAlignment(LI.getType()), DL, &LI, AC, DT); + Op, DL.getPrefTypeAlignment(LI.getType()), DL, &LI, &AC, &DT); unsigned LoadAlign = LI.getAlignment(); unsigned EffectiveLoadAlign = LoadAlign != 0 ? LoadAlign : DL.getABITypeAlignment(LI.getType()); @@ -825,22 +850,11 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // where there are several consecutive memory accesses to the same location, // separated by a few arithmetic operations. BasicBlock::iterator BBI(LI); - AAMDNodes AATags; bool IsLoadCSE = false; - if (Value *AvailableVal = - FindAvailableLoadedValue(&LI, LI.getParent(), BBI, - DefMaxInstsToScan, AA, &AATags, &IsLoadCSE)) { - if (IsLoadCSE) { - LoadInst *NLI = cast<LoadInst>(AvailableVal); - unsigned KnownIDs[] = { - LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_range, - LLVMContext::MD_invariant_load, LLVMContext::MD_nonnull, - LLVMContext::MD_invariant_group, LLVMContext::MD_align, - LLVMContext::MD_dereferenceable, - LLVMContext::MD_dereferenceable_or_null}; - combineMetadata(NLI, &LI, KnownIDs); - }; + if (Value *AvailableVal = FindAvailableLoadedValue( + &LI, LI.getParent(), BBI, DefMaxInstsToScan, AA, &IsLoadCSE)) { + if (IsLoadCSE) + combineMetadataForCSE(cast<LoadInst>(AvailableVal), &LI); return replaceInstUsesWith( LI, Builder->CreateBitOrPointerCast(AvailableVal, LI.getType(), @@ -1005,19 +1019,26 @@ static bool combineStoreToValueType(InstCombiner &IC, StoreInst &SI) { if (!SI.isUnordered()) return false; + // swifterror values can't be bitcasted. + if (SI.getPointerOperand()->isSwiftError()) + return false; + Value *V = SI.getValueOperand(); // Fold away bit casts of the stored value by storing the original type. 
if (auto *BC = dyn_cast<BitCastInst>(V)) { V = BC->getOperand(0); - combineStoreToNewValue(IC, SI, V); - return true; + if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) { + combineStoreToNewValue(IC, SI, V); + return true; + } } - if (Value *U = likeBitCastFromVector(IC, V)) { - combineStoreToNewValue(IC, SI, U); - return true; - } + if (Value *U = likeBitCastFromVector(IC, V)) + if (!SI.isAtomic() || isSupportedAtomicType(U->getType())) { + combineStoreToNewValue(IC, SI, U); + return true; + } // FIXME: We should also canonicalize stores of vectors when their elements // are cast to other types. @@ -1169,7 +1190,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // Attempt to improve the alignment. unsigned KnownAlign = getOrEnforceKnownAlignment( - Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, AC, DT); + Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, &AC, &DT); unsigned StoreAlign = SI.getAlignment(); unsigned EffectiveStoreAlign = StoreAlign != 0 ? StoreAlign : DL.getABITypeAlignment(Val->getType()); @@ -1293,7 +1314,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { assert(SI.isUnordered() && "this code has not been auditted for volatile or ordered store case"); - + BasicBlock *StoreBB = SI.getParent(); // Check to see if the successor block has exactly two incoming edges. If diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 788097f..45a19fb 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -48,8 +48,8 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, BinaryOperator *I = dyn_cast<BinaryOperator>(V); if (I && I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, - IC.getAssumptionCache(), &CxtI, - IC.getDominatorTree())) { + &IC.getAssumptionCache(), &CxtI, + &IC.getDominatorTree())) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { @@ -179,7 +179,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyMulInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) @@ -267,14 +267,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { // Simplify mul instructions with a constant RHS. if (isa<Constant>(Op1)) { - // Try to fold constant mul into select arguments. - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (Instruction *FoldedMul = foldOpWithConstantIntoOperand(I)) + return FoldedMul; // Canonicalize (X+C1)*CI -> X*CI+C1*CI. { @@ -389,6 +383,80 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + // Check for (mul (sext x), y), see if we can merge this into an + // integer mul followed by a sext. 
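// Illustrative sketch, not part of this patch: why the narrowing
// (mul (sext x), (sext y)) --> (sext (mul x, y)) needs the signed-overflow
// check (WillNotOverflowSignedMul). Two's-complement wrap of the narrow
// multiply is assumed for the counterexample, as on all LLVM hosts.
#include <cassert>
#include <cstdint>

int main() {
  // No overflow in i16: the narrow multiply, extended afterwards, matches.
  int16_t a = 100, b = 200;
  assert((int32_t)a * (int32_t)b == (int32_t)(int16_t)(a * b));

  // Overflow in i16: the results differ, so the fold must be rejected.
  int16_t c = 300, d = 300;
  assert((int32_t)c * (int32_t)d != (int32_t)(int16_t)(c * d));
  return 0;
}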
+ if (SExtInst *Op0Conv = dyn_cast<SExtInst>(Op0)) { + // (mul (sext x), cst) --> (sext (mul x, cst')) + if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { + if (Op0Conv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); + if (ConstantExpr::getSExt(CI, I.getType()) == Op1C && + WillNotOverflowSignedMul(Op0Conv->getOperand(0), CI, I)) { + // Insert the new, smaller mul. + Value *NewMul = + Builder->CreateNSWMul(Op0Conv->getOperand(0), CI, "mulconv"); + return new SExtInst(NewMul, I.getType()); + } + } + } + + // (mul (sext x), (sext y)) --> (sext (mul int x, y)) + if (SExtInst *Op1Conv = dyn_cast<SExtInst>(Op1)) { + // Only do this if x/y have the same type, if at last one of them has a + // single use (so we don't increase the number of sexts), and if the + // integer mul will not overflow. + if (Op0Conv->getOperand(0)->getType() == + Op1Conv->getOperand(0)->getType() && + (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && + WillNotOverflowSignedMul(Op0Conv->getOperand(0), + Op1Conv->getOperand(0), I)) { + // Insert the new integer mul. + Value *NewMul = Builder->CreateNSWMul( + Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); + return new SExtInst(NewMul, I.getType()); + } + } + } + + // Check for (mul (zext x), y), see if we can merge this into an + // integer mul followed by a zext. + if (auto *Op0Conv = dyn_cast<ZExtInst>(Op0)) { + // (mul (zext x), cst) --> (zext (mul x, cst')) + if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { + if (Op0Conv->hasOneUse()) { + Constant *CI = + ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); + if (ConstantExpr::getZExt(CI, I.getType()) == Op1C && + computeOverflowForUnsignedMul(Op0Conv->getOperand(0), CI, &I) == + OverflowResult::NeverOverflows) { + // Insert the new, smaller mul. + Value *NewMul = + Builder->CreateNUWMul(Op0Conv->getOperand(0), CI, "mulconv"); + return new ZExtInst(NewMul, I.getType()); + } + } + } + + // (mul (zext x), (zext y)) --> (zext (mul int x, y)) + if (auto *Op1Conv = dyn_cast<ZExtInst>(Op1)) { + // Only do this if x/y have the same type, if at last one of them has a + // single use (so we don't increase the number of zexts), and if the + // integer mul will not overflow. + if (Op0Conv->getOperand(0)->getType() == + Op1Conv->getOperand(0)->getType() && + (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && + computeOverflowForUnsignedMul(Op0Conv->getOperand(0), + Op1Conv->getOperand(0), + &I) == OverflowResult::NeverOverflows) { + // Insert the new integer mul. + Value *NewMul = Builder->CreateNUWMul( + Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); + return new ZExtInst(NewMul, I.getType()); + } + } + } + if (!I.hasNoSignedWrap() && WillNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoSignedWrap(true); @@ -545,21 +613,15 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { std::swap(Op0, Op1); if (Value *V = - SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) + SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); // Simplify mul instructions with a constant RHS. if (isa<Constant>(Op1)) { - // Try to fold constant mul into select arguments. 
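// Illustrative sketch, not part of this patch: the effect of folding a
// constant operand into a select's arms (now routed through
// foldOpWithConstantIntoOperand). "(select c, 3, 5) * 2" becomes
// "select c, 6, 10", removing the multiply from the runtime path.
#include <cassert>

int main() {
  for (bool c : {false, true}) {
    int Before = (c ? 3 : 5) * 2;
    int After = c ? 6 : 10;
    assert(Before == After);
  }
  return 0;
}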
- if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (Instruction *FoldedMul = foldOpWithConstantIntoOperand(I)) + return FoldedMul; // (fmul X, -1.0) --> (fsub -0.0, X) if (match(Op1, m_SpecificFP(-1.0))) { @@ -709,7 +771,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { BuilderTy::FastMathFlagGuard Guard(*Builder); Builder->setFastMathFlags(I.getFastMathFlags()); Value *T = Builder->CreateFMul(Opnd1, Opnd1); - Value *R = Builder->CreateFMul(T, Y); R->takeName(&I); return replaceInstUsesWith(I, R); @@ -883,14 +944,9 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { } } - if (*C2 != 0) { // avoid X udiv 0 - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; - } + if (*C2 != 0) // avoid X udiv 0 + if (Instruction *FoldedDiv = foldOpWithConstantIntoOperand(I)) + return FoldedDiv; } } @@ -991,19 +1047,22 @@ static Instruction *foldUDivNegCst(Value *Op0, Value *Op1, } // X udiv (C1 << N), where C1 is "1<<C2" --> X >> (N+C2) +// X udiv (zext (C1 << N)), where C1 is "1<<C2" --> X >> (N+C2) static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I, InstCombiner &IC) { - Instruction *ShiftLeft = cast<Instruction>(Op1); - if (isa<ZExtInst>(ShiftLeft)) - ShiftLeft = cast<Instruction>(ShiftLeft->getOperand(0)); - - const APInt &CI = - cast<Constant>(ShiftLeft->getOperand(0))->getUniqueInteger(); - Value *N = ShiftLeft->getOperand(1); - if (CI != 1) - N = IC.Builder->CreateAdd(N, ConstantInt::get(N->getType(), CI.logBase2())); - if (ZExtInst *Z = dyn_cast<ZExtInst>(Op1)) - N = IC.Builder->CreateZExt(N, Z->getDestTy()); + Value *ShiftLeft; + if (!match(Op1, m_ZExt(m_Value(ShiftLeft)))) + ShiftLeft = Op1; + + const APInt *CI; + Value *N; + if (!match(ShiftLeft, m_Shl(m_APInt(CI), m_Value(N)))) + llvm_unreachable("match should never fail here!"); + if (*CI != 1) + N = IC.Builder->CreateAdd(N, + ConstantInt::get(N->getType(), CI->logBase2())); + if (Op1 != ShiftLeft) + N = IC.Builder->CreateZExt(N, Op1->getType()); BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N); if (I.isExact()) LShr->setIsExact(); @@ -1059,7 +1118,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyUDivInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1132,7 +1191,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifySDivInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle the integer div common cases @@ -1195,7 +1254,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return BO; } - if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, AC, &I, DT)) { + if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, &AC, &I, &DT)) { // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y) // Safe because the only negative value (1 << Y) can take on is // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have @@ -1247,7 +1306,7 @@ Instruction 
*InstCombiner::visitFDiv(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) @@ -1367,6 +1426,16 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { } } + Value *LHS; + Value *RHS; + + // -x / -y -> x / y + if (match(Op0, m_FNeg(m_Value(LHS))) && match(Op1, m_FNeg(m_Value(RHS)))) { + I.setOperand(0, LHS); + I.setOperand(1, RHS); + return &I; + } + return nullptr; } @@ -1421,7 +1490,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifyURemInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) @@ -1434,7 +1503,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { I.getType()); // X urem Y -> X and Y-1, where Y is a power of 2, - if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, AC, &I, DT)) { + if (isKnownToBeAPowerOfTwo(Op1, DL, /*OrZero*/ true, 0, &AC, &I, &DT)) { Constant *N1 = Constant::getAllOnesValue(I.getType()); Value *Add = Builder->CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); @@ -1447,6 +1516,14 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { return replaceInstUsesWith(I, Ext); } + // X urem C -> X < C ? X : X - C, where C >= signbit. + const APInt *DivisorC; + if (match(Op1, m_APInt(DivisorC)) && DivisorC->isNegative()) { + Value *Cmp = Builder->CreateICmpULT(Op0, Op1); + Value *Sub = Builder->CreateSub(Op0, Op1); + return SelectInst::Create(Cmp, Op0, Sub); + } + return nullptr; } @@ -1456,7 +1533,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC)) + if (Value *V = SimplifySRemInst(Op0, Op1, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle the integer rem common cases @@ -1532,7 +1609,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 79a4912..4cbffe9 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -18,11 +18,27 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/IR/DebugInfo.h" using namespace llvm; using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" +/// The PHI arguments will be folded into a single operation with a PHI node +/// as input. The debug location of the single operation will be the merged +/// locations of the original PHI node arguments. 
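The unsigned div/rem folds above (foldUDivShl and the new visitURem cases) rest on simple arithmetic identities; a small standalone check with arbitrary test values, not taken from the patch:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 0xDEADBEEFu;

  // X udiv (C1 << N) with C1 == (1 << C2)  -->  X >> (N + C2)
  uint32_t C2 = 3, N = 4;
  assert(X / ((1u << C2) << N) == X >> (N + C2));

  // X urem Y  -->  X & (Y - 1) when Y is a power of 2
  uint32_t Y = 1u << 12;
  assert(X % Y == (X & (Y - 1)));

  // X urem C  -->  X < C ? X : X - C when C has the sign bit set
  // (the quotient can only be 0 or 1 in that case).
  uint32_t C = 0x90000000u;
  assert(X % C == (X < C ? X : X - C));
  return 0;
}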
+DebugLoc InstCombiner::PHIArgMergedDebugLoc(PHINode &PN) { + auto *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); + const DILocation *Loc = FirstInst->getDebugLoc(); + + for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { + auto *I = cast<Instruction>(PN.getIncomingValue(i)); + Loc = DILocation::getMergedLocation(Loc, I->getDebugLoc()); + } + + return Loc; +} + /// If we have something like phi [add (a,b), add(a,c)] and if a/b/c and the /// adds all have a single use, turn this into a phi and a single binop. Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { @@ -101,7 +117,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { if (CmpInst *CIOp = dyn_cast<CmpInst>(FirstInst)) { CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), LHSVal, RHSVal); - NewCI->setDebugLoc(FirstInst->getDebugLoc()); + NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewCI; } @@ -114,7 +130,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) NewBinOp->andIRFlags(PN.getIncomingValue(i)); - NewBinOp->setDebugLoc(FirstInst->getDebugLoc()); + NewBinOp->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewBinOp; } @@ -223,7 +239,7 @@ Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) { GetElementPtrInst::Create(FirstInst->getSourceElementType(), Base, makeArrayRef(FixedOperands).slice(1)); if (AllInBounds) NewGEP->setIsInBounds(); - NewGEP->setDebugLoc(FirstInst->getDebugLoc()); + NewGEP->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewGEP; } @@ -383,7 +399,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { for (Value *IncValue : PN.incoming_values()) cast<LoadInst>(IncValue)->setVolatile(false); - NewLI->setDebugLoc(FirstLI->getDebugLoc()); + NewLI->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewLI; } @@ -549,7 +565,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { if (CastInst *FirstCI = dyn_cast<CastInst>(FirstInst)) { CastInst *NewCI = CastInst::Create(FirstCI->getOpcode(), PhiVal, PN.getType()); - NewCI->setDebugLoc(FirstInst->getDebugLoc()); + NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewCI; } @@ -560,14 +576,14 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) BinOp->andIRFlags(PN.getIncomingValue(i)); - BinOp->setDebugLoc(FirstInst->getDebugLoc()); + BinOp->setDebugLoc(PHIArgMergedDebugLoc(PN)); return BinOp; } CmpInst *CIOp = cast<CmpInst>(FirstInst); CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), PhiVal, ConstantOp); - NewCI->setDebugLoc(FirstInst->getDebugLoc()); + NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); return NewCI; } @@ -835,8 +851,8 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // needed piece. 
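PHIArgMergedDebugLoc folds a pairwise merge over every incoming value's location before it is attached to the combined instruction. A rough sketch of that left-fold shape in plain C++; mergeLoc and the string "locations" are hypothetical stand-ins for the real DILocation merge, used only to show the reduction structure:

#include <numeric>
#include <string>
#include <vector>

// Hypothetical stand-in for the debug-location merge: keep a location only
// if both arguments agree on it, otherwise fall back to an "unknown" value.
static std::string mergeLoc(const std::string &A, const std::string &B) {
  return A == B ? A : std::string("<merged/unknown>");
}

// Same shape as the loop above: start with the first argument's location and
// merge the rest in one at a time. Assumes at least one incoming location,
// mirroring a PHI with at least one incoming value.
std::string mergedLocForPhiArgs(const std::vector<std::string> &ArgLocs) {
  return std::accumulate(ArgLocs.begin() + 1, ArgLocs.end(), ArgLocs.front(),
                         mergeLoc);
}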
if (PHINode *OldInVal = dyn_cast<PHINode>(PN->getIncomingValue(i))) if (PHIsInspected.count(OldInVal)) { - unsigned RefPHIId = std::find(PHIsToSlice.begin(),PHIsToSlice.end(), - OldInVal)-PHIsToSlice.begin(); + unsigned RefPHIId = + find(PHIsToSlice, OldInVal) - PHIsToSlice.begin(); PHIUsers.push_back(PHIUsageRecord(RefPHIId, Offset, cast<Instruction>(Res))); ++UserE; @@ -864,7 +880,7 @@ Instruction *InstCombiner::SliceUpIllegalIntegerPHI(PHINode &FirstPhi) { // PHINode simplification // Instruction *InstCombiner::visitPHINode(PHINode &PN) { - if (Value *V = SimplifyInstruction(&PN, DL, TLI, DT, AC)) + if (Value *V = SimplifyInstruction(&PN, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(PN, V); if (Instruction *Result = FoldPHIArgZextsIntoPHI(PN)) @@ -921,7 +937,7 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { Instruction *CtxI = PN.getIncomingBlock(i)->getTerminator(); Value *VA = PN.getIncomingValue(i); - if (isKnownNonZero(VA, DL, 0, AC, CtxI, DT)) { + if (isKnownNonZero(VA, DL, 0, &AC, CtxI, &DT)) { if (!NonZeroConst) NonZeroConst = GetAnyNonZeroConstInt(PN); PN.setIncomingValue(i, NonZeroConst); diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 8f1ff8a..3664484 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -15,6 +15,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" using namespace llvm; using namespace PatternMatch; @@ -78,7 +79,7 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder, /// a bitmask indicating which operands of this instruction are foldable if they /// equal the other incoming value of the select. /// -static unsigned GetSelectFoldableOperands(Instruction *I) { +static unsigned getSelectFoldableOperands(Instruction *I) { switch (I->getOpcode()) { case Instruction::Add: case Instruction::Mul: @@ -98,7 +99,7 @@ static unsigned GetSelectFoldableOperands(Instruction *I) { /// For the same transformation as the previous function, return the identity /// constant that goes into the select. -static Constant *GetSelectFoldableConstant(Instruction *I) { +static Constant *getSelectFoldableConstant(Instruction *I) { switch (I->getOpcode()) { default: llvm_unreachable("This cannot happen!"); case Instruction::Add: @@ -117,7 +118,7 @@ static Constant *GetSelectFoldableConstant(Instruction *I) { } /// We have (select c, TI, FI), and we know that TI and FI have the same opcode. -Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, +Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI) { // If this is a cast from the same type, merge. if (TI->getNumOperands() == 1 && TI->isCast()) { @@ -154,19 +155,19 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, } // Fold this by inserting a select from the input values. 
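getSelectFoldableConstant above supplies the identity element that lets a select be pushed into one operand of a foldable instruction. A standalone check of that identity for add and and (arbitrary values, not part of the patch):

#include <cassert>

int main() {
  int x = 7, y = 42;
  for (int k = 0; k < 2; ++k) {
    const bool c = (k == 1);
    // select c, (x + y), y  ==  (select c, x, 0) + y     (0 is add's identity)
    assert((c ? x + y : y) == ((c ? x : 0) + y));
    // select c, (x & y), y  ==  (select c, x, -1) & y    (-1 is and's identity)
    assert((c ? (x & y) : y) == ((c ? x : -1) & y));
  }
  return 0;
}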
- Value *NewSI = Builder->CreateSelect(SI.getCondition(), TI->getOperand(0), - FI->getOperand(0), SI.getName()+".v"); + Value *NewSI = + Builder->CreateSelect(SI.getCondition(), TI->getOperand(0), + FI->getOperand(0), SI.getName() + ".v", &SI); return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI, TI->getType()); } - // TODO: This function ends awkwardly in unreachable - fix to be more normal. - // Only handle binary operators with one-use here. As with the cast case // above, it may be possible to relax the one-use constraint, but that needs // be examined carefully since it may not reduce the total number of // instructions. - if (!isa<BinaryOperator>(TI) || !TI->hasOneUse() || !FI->hasOneUse()) + BinaryOperator *BO = dyn_cast<BinaryOperator>(TI); + if (!BO || !TI->hasOneUse() || !FI->hasOneUse()) return nullptr; // Figure out if the operations have any operands in common. @@ -199,16 +200,11 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI, } // If we reach here, they do have operations in common. - Value *NewSI = Builder->CreateSelect(SI.getCondition(), OtherOpT, - OtherOpF, SI.getName()+".v"); - - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TI)) { - if (MatchIsOpZero) - return BinaryOperator::Create(BO->getOpcode(), MatchOp, NewSI); - else - return BinaryOperator::Create(BO->getOpcode(), NewSI, MatchOp); - } - llvm_unreachable("Shouldn't get here"); + Value *NewSI = Builder->CreateSelect(SI.getCondition(), OtherOpT, OtherOpF, + SI.getName() + ".v", &SI); + Value *Op0 = MatchIsOpZero ? MatchOp : NewSI; + Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; + return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); } static bool isSelect01(Constant *C1, Constant *C2) { @@ -226,14 +222,14 @@ static bool isSelect01(Constant *C1, Constant *C2) { /// Try to fold the select into one of the operands to allow further /// optimization. -Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, +Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, Value *FalseVal) { // See the comment above GetSelectFoldableOperands for a description of the // transformation we are doing here. if (Instruction *TVI = dyn_cast<Instruction>(TrueVal)) { if (TVI->hasOneUse() && TVI->getNumOperands() == 2 && !isa<Constant>(FalseVal)) { - if (unsigned SFO = GetSelectFoldableOperands(TVI)) { + if (unsigned SFO = getSelectFoldableOperands(TVI)) { unsigned OpToFold = 0; if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { OpToFold = 1; @@ -242,7 +238,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = GetSelectFoldableConstant(TVI); + Constant *C = getSelectFoldableConstant(TVI); Value *OOp = TVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. 
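foldSelectOpOp now emits a single binop of a narrower select whenever the two arms share an operand. The underlying identity, checked directly in plain C++ with arbitrary values:

#include <cassert>

int main() {
  int a = 5, b = 9, c = 3;
  for (int k = 0; k < 2; ++k) {
    const bool cond = (k == 1);
    // select cond, (a + c), (b + c)  ==  (select cond, a, b) + c
    assert((cond ? a + c : b + c) == ((cond ? a : b) + c));
    // Same shape with the shared operand on the other side.
    assert((cond ? c - a : c - b) == (c - (cond ? a : b)));
  }
  return 0;
}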
@@ -263,7 +259,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, if (Instruction *FVI = dyn_cast<Instruction>(FalseVal)) { if (FVI->hasOneUse() && FVI->getNumOperands() == 2 && !isa<Constant>(TrueVal)) { - if (unsigned SFO = GetSelectFoldableOperands(FVI)) { + if (unsigned SFO = getSelectFoldableOperands(FVI)) { unsigned OpToFold = 0; if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { OpToFold = 1; @@ -272,7 +268,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = GetSelectFoldableConstant(FVI); + Constant *C = getSelectFoldableConstant(FVI); Value *OOp = FVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. @@ -411,103 +407,151 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, return nullptr; } +/// Return true if we find and adjust an icmp+select pattern where the compare +/// is with a constant that can be incremented or decremented to match the +/// minimum or maximum idiom. +static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *CmpLHS = Cmp.getOperand(0); + Value *CmpRHS = Cmp.getOperand(1); + Value *TrueVal = Sel.getTrueValue(); + Value *FalseVal = Sel.getFalseValue(); + + // We may move or edit the compare, so make sure the select is the only user. + const APInt *CmpC; + if (!Cmp.hasOneUse() || !match(CmpRHS, m_APInt(CmpC))) + return false; + + // These transforms only work for selects of integers or vector selects of + // integer vectors. + Type *SelTy = Sel.getType(); + auto *SelEltTy = dyn_cast<IntegerType>(SelTy->getScalarType()); + if (!SelEltTy || SelTy->isVectorTy() != Cmp.getType()->isVectorTy()) + return false; + + Constant *AdjustedRHS; + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT) + AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC + 1); + else if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) + AdjustedRHS = ConstantInt::get(CmpRHS->getType(), *CmpC - 1); + else + return false; + + // X > C ? X : C+1 --> X < C+1 ? C+1 : X + // X < C ? X : C-1 --> X > C-1 ? C-1 : X + if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || + (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) { + ; // Nothing to do here. Values match without any sign/zero extension. + } + // Types do not match. Instead of calculating this with mixed types, promote + // all to the larger type. This enables scalar evolution to analyze this + // expression. + else if (CmpRHS->getType()->getScalarSizeInBits() < SelEltTy->getBitWidth()) { + Constant *SextRHS = ConstantExpr::getSExt(AdjustedRHS, SelTy); + + // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X + // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X + // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X + // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X + if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && SextRHS == FalseVal) { + CmpLHS = TrueVal; + AdjustedRHS = SextRHS; + } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) && + SextRHS == TrueVal) { + CmpLHS = FalseVal; + AdjustedRHS = SextRHS; + } else if (Cmp.isUnsigned()) { + Constant *ZextRHS = ConstantExpr::getZExt(AdjustedRHS, SelTy); + // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X + // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? 
C-1 : X + // zext + signed compare cannot be changed: + // 0xff <s 0x00, but 0x00ff >s 0x0000 + if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && ZextRHS == FalseVal) { + CmpLHS = TrueVal; + AdjustedRHS = ZextRHS; + } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) && + ZextRHS == TrueVal) { + CmpLHS = FalseVal; + AdjustedRHS = ZextRHS; + } else { + return false; + } + } else { + return false; + } + } else { + return false; + } + + Pred = ICmpInst::getSwappedPredicate(Pred); + CmpRHS = AdjustedRHS; + std::swap(FalseVal, TrueVal); + Cmp.setPredicate(Pred); + Cmp.setOperand(0, CmpLHS); + Cmp.setOperand(1, CmpRHS); + Sel.setOperand(1, TrueVal); + Sel.setOperand(2, FalseVal); + Sel.swapProfMetadata(); + + // Move the compare instruction right before the select instruction. Otherwise + // the sext/zext value may be defined after the compare instruction uses it. + Cmp.moveBefore(&Sel); + + return true; +} + +/// If this is an integer min/max where the select's 'true' operand is a +/// constant, canonicalize that constant to the 'false' operand: +/// select (icmp Pred X, C), C, X --> select (icmp Pred' X, C), X, C +static Instruction * +canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + // TODO: We should also canonicalize min/max when the select has a different + // constant value than the cmp constant, but we need to fix the backend first. + if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)) || + !isa<Constant>(Sel.getTrueValue()) || + isa<Constant>(Sel.getFalseValue()) || + Cmp.getOperand(1) != Sel.getTrueValue()) + return nullptr; + + // Canonicalize the compare predicate based on whether we have min or max. + Value *LHS, *RHS; + ICmpInst::Predicate NewPred; + SelectPatternResult SPR = matchSelectPattern(&Sel, LHS, RHS); + switch (SPR.Flavor) { + case SPF_SMIN: NewPred = ICmpInst::ICMP_SLT; break; + case SPF_UMIN: NewPred = ICmpInst::ICMP_ULT; break; + case SPF_SMAX: NewPred = ICmpInst::ICMP_SGT; break; + case SPF_UMAX: NewPred = ICmpInst::ICMP_UGT; break; + default: return nullptr; + } + + // Canonicalize the constant to the right side. + if (isa<Constant>(LHS)) + std::swap(LHS, RHS); + + Value *NewCmp = Builder.CreateICmp(NewPred, LHS, RHS); + SelectInst *NewSel = SelectInst::Create(NewCmp, LHS, RHS, "", nullptr, &Sel); + + // We swapped the select operands, so swap the metadata too. + NewSel->swapProfMetadata(); + return NewSel; +} + /// Visit a SelectInst that has an ICmpInst as its first operand. -Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, - ICmpInst *ICI) { - bool Changed = false; +Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, + ICmpInst *ICI) { + if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *Builder)) + return NewSel; + + bool Changed = adjustMinMax(SI, *ICI); + ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); - // Check cases where the comparison is with a constant that - // can be adjusted to fit the min/max idiom. We may move or edit ICI - // here, so make sure the select is the only user. - if (ICI->hasOneUse()) - if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) { - switch (Pred) { - default: break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGT: { - // These transformations only work for selects over integers. 
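adjustMinMax rewrites the compare constant so the select matches the canonical min/max form. A quick check that the adjusted form selects the same value, using a constant where C+1 does not wrap (values are illustrative):

#include <algorithm>
#include <cassert>

int main() {
  const int C = 41;                       // any constant where C + 1 does not wrap
  for (int X = C - 3; X <= C + 3; ++X) {
    int before = X > C ? X : C + 1;       // select (icmp sgt X, C), X, C+1
    int after  = X < C + 1 ? C + 1 : X;   // select (icmp slt X, C+1), C+1, X
    assert(before == after);
    assert(before == std::max(X, C + 1)); // i.e. both are smax(X, C+1)
  }
  return 0;
}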
- IntegerType *SelectTy = dyn_cast<IntegerType>(SI.getType()); - if (!SelectTy) - break; - - Constant *AdjustedRHS; - if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_SGT) - AdjustedRHS = ConstantInt::get(CI->getContext(), CI->getValue() + 1); - else // (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) - AdjustedRHS = ConstantInt::get(CI->getContext(), CI->getValue() - 1); - - // X > C ? X : C+1 --> X < C+1 ? C+1 : X - // X < C ? X : C-1 --> X > C-1 ? C-1 : X - if ((CmpLHS == TrueVal && AdjustedRHS == FalseVal) || - (CmpLHS == FalseVal && AdjustedRHS == TrueVal)) - ; // Nothing to do here. Values match without any sign/zero extension. - - // Types do not match. Instead of calculating this with mixed types - // promote all to the larger type. This enables scalar evolution to - // analyze this expression. - else if (CmpRHS->getType()->getScalarSizeInBits() - < SelectTy->getBitWidth()) { - Constant *sextRHS = ConstantExpr::getSExt(AdjustedRHS, SelectTy); - - // X = sext x; x >s c ? X : C+1 --> X = sext x; X <s C+1 ? C+1 : X - // X = sext x; x <s c ? X : C-1 --> X = sext x; X >s C-1 ? C-1 : X - // X = sext x; x >u c ? X : C+1 --> X = sext x; X <u C+1 ? C+1 : X - // X = sext x; x <u c ? X : C-1 --> X = sext x; X >u C-1 ? C-1 : X - if (match(TrueVal, m_SExt(m_Specific(CmpLHS))) && - sextRHS == FalseVal) { - CmpLHS = TrueVal; - AdjustedRHS = sextRHS; - } else if (match(FalseVal, m_SExt(m_Specific(CmpLHS))) && - sextRHS == TrueVal) { - CmpLHS = FalseVal; - AdjustedRHS = sextRHS; - } else if (ICI->isUnsigned()) { - Constant *zextRHS = ConstantExpr::getZExt(AdjustedRHS, SelectTy); - // X = zext x; x >u c ? X : C+1 --> X = zext x; X <u C+1 ? C+1 : X - // X = zext x; x <u c ? X : C-1 --> X = zext x; X >u C-1 ? C-1 : X - // zext + signed compare cannot be changed: - // 0xff <s 0x00, but 0x00ff >s 0x0000 - if (match(TrueVal, m_ZExt(m_Specific(CmpLHS))) && - zextRHS == FalseVal) { - CmpLHS = TrueVal; - AdjustedRHS = zextRHS; - } else if (match(FalseVal, m_ZExt(m_Specific(CmpLHS))) && - zextRHS == TrueVal) { - CmpLHS = FalseVal; - AdjustedRHS = zextRHS; - } else - break; - } else - break; - } else - break; - - Pred = ICmpInst::getSwappedPredicate(Pred); - CmpRHS = AdjustedRHS; - std::swap(FalseVal, TrueVal); - ICI->setPredicate(Pred); - ICI->setOperand(0, CmpLHS); - ICI->setOperand(1, CmpRHS); - SI.setOperand(1, TrueVal); - SI.setOperand(2, FalseVal); - - // Move ICI instruction right before the select instruction. Otherwise - // the sext/zext value may be defined after the ICI instruction uses it. - ICI->moveBefore(&SI); - - Changed = true; - break; - } - } - } - // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 // and (X <s 0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1 // FIXME: Type and constness constraints could be lifted, but we have to @@ -623,7 +667,7 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, /// /// because Y is not live in BB1/BB2. /// -static bool CanSelectOperandBeMappingIntoPredBlock(const Value *V, +static bool canSelectOperandBeMappingIntoPredBlock(const Value *V, const SelectInst &SI) { // If the value is a non-instruction value like a constant or argument, it // can always be mapped. @@ -651,7 +695,7 @@ static bool CanSelectOperandBeMappingIntoPredBlock(const Value *V, /// We have an SPF (e.g. 
a min or max) of an SPF of the form: /// SPF2(SPF1(A, B), C) -Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, +Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, Value *A, Value *B, Instruction &Outer, @@ -675,28 +719,24 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, } if (SPF1 == SPF2) { - if (ConstantInt *CB = dyn_cast<ConstantInt>(B)) { - if (ConstantInt *CC = dyn_cast<ConstantInt>(C)) { - const APInt &ACB = CB->getValue(); - const APInt &ACC = CC->getValue(); - - // MIN(MIN(A, 23), 97) -> MIN(A, 23) - // MAX(MAX(A, 97), 23) -> MAX(A, 97) - if ((SPF1 == SPF_UMIN && ACB.ule(ACC)) || - (SPF1 == SPF_SMIN && ACB.sle(ACC)) || - (SPF1 == SPF_UMAX && ACB.uge(ACC)) || - (SPF1 == SPF_SMAX && ACB.sge(ACC))) - return replaceInstUsesWith(Outer, Inner); - - // MIN(MIN(A, 97), 23) -> MIN(A, 23) - // MAX(MAX(A, 23), 97) -> MAX(A, 97) - if ((SPF1 == SPF_UMIN && ACB.ugt(ACC)) || - (SPF1 == SPF_SMIN && ACB.sgt(ACC)) || - (SPF1 == SPF_UMAX && ACB.ult(ACC)) || - (SPF1 == SPF_SMAX && ACB.slt(ACC))) { - Outer.replaceUsesOfWith(Inner, A); - return &Outer; - } + const APInt *CB, *CC; + if (match(B, m_APInt(CB)) && match(C, m_APInt(CC))) { + // MIN(MIN(A, 23), 97) -> MIN(A, 23) + // MAX(MAX(A, 97), 23) -> MAX(A, 97) + if ((SPF1 == SPF_UMIN && CB->ule(*CC)) || + (SPF1 == SPF_SMIN && CB->sle(*CC)) || + (SPF1 == SPF_UMAX && CB->uge(*CC)) || + (SPF1 == SPF_SMAX && CB->sge(*CC))) + return replaceInstUsesWith(Outer, Inner); + + // MIN(MIN(A, 97), 23) -> MIN(A, 23) + // MAX(MAX(A, 23), 97) -> MAX(A, 97) + if ((SPF1 == SPF_UMIN && CB->ugt(*CC)) || + (SPF1 == SPF_SMIN && CB->sgt(*CC)) || + (SPF1 == SPF_UMAX && CB->ult(*CC)) || + (SPF1 == SPF_SMAX && CB->slt(*CC))) { + Outer.replaceUsesOfWith(Inner, A); + return &Outer; } } } @@ -712,8 +752,9 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, if ((SPF1 == SPF_ABS && SPF2 == SPF_NABS) || (SPF1 == SPF_NABS && SPF2 == SPF_ABS)) { SelectInst *SI = cast<SelectInst>(Inner); - Value *NewSI = Builder->CreateSelect( - SI->getCondition(), SI->getFalseValue(), SI->getTrueValue()); + Value *NewSI = + Builder->CreateSelect(SI->getCondition(), SI->getFalseValue(), + SI->getTrueValue(), SI->getName(), SI); return replaceInstUsesWith(Outer, NewSI); } @@ -895,7 +936,7 @@ static Instruction *foldAddSubSelect(SelectInst &SI, if (AddOp != TI) std::swap(NewTrueOp, NewFalseOp); Value *NewSel = Builder.CreateSelect(CondVal, NewTrueOp, NewFalseOp, - SI.getName() + ".p"); + SI.getName() + ".p", &SI); if (SI.getType()->isFPOrFPVectorTy()) { Instruction *RI = @@ -912,6 +953,147 @@ static Instruction *foldAddSubSelect(SelectInst &SI, return nullptr; } +Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { + Instruction *ExtInst; + if (!match(Sel.getTrueValue(), m_Instruction(ExtInst)) && + !match(Sel.getFalseValue(), m_Instruction(ExtInst))) + return nullptr; + + auto ExtOpcode = ExtInst->getOpcode(); + if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt) + return nullptr; + + // TODO: Handle larger types? That requires adjusting FoldOpIntoSelect too. + Value *X = ExtInst->getOperand(0); + Type *SmallType = X->getType(); + if (!SmallType->getScalarType()->isIntegerTy(1)) + return nullptr; + + Constant *C; + if (!match(Sel.getTrueValue(), m_Constant(C)) && + !match(Sel.getFalseValue(), m_Constant(C))) + return nullptr; + + // If the constant is the same after truncation to the smaller type and + // extension to the original type, we can narrow the select. 
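foldSPFofSPF collapses a nested min/max whose two constants make one of the clamps redundant. The arithmetic fact behind it, checked over a small range (constants match the comments above, the range is arbitrary):

#include <algorithm>
#include <cassert>

int main() {
  for (int a = 0; a < 200; ++a) {
    // MIN(MIN(a, 23), 97) -> MIN(a, 23)   (inner constant already dominates)
    assert(std::min(std::min(a, 23), 97) == std::min(a, 23));
    // MIN(MIN(a, 97), 23) -> MIN(a, 23)   (outer constant is the tighter bound)
    assert(std::min(std::min(a, 97), 23) == std::min(a, 23));
    // MAX(MAX(a, 97), 23) -> MAX(a, 97)
    assert(std::max(std::max(a, 97), 23) == std::max(a, 97));
  }
  return 0;
}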
+ Value *Cond = Sel.getCondition(); + Type *SelType = Sel.getType(); + Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); + Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); + if (ExtC == C) { + Value *TruncCVal = cast<Value>(TruncC); + if (ExtInst == Sel.getFalseValue()) + std::swap(X, TruncCVal); + + // select Cond, (ext X), C --> ext(select Cond, X, C') + // select Cond, C, (ext X) --> ext(select Cond, C', X) + Value *NewSel = Builder->CreateSelect(Cond, X, TruncCVal, "narrow", &Sel); + return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType); + } + + // If one arm of the select is the extend of the condition, replace that arm + // with the extension of the appropriate known bool value. + if (Cond == X) { + if (ExtInst == Sel.getTrueValue()) { + // select X, (sext X), C --> select X, -1, C + // select X, (zext X), C --> select X, 1, C + Constant *One = ConstantInt::getTrue(SmallType); + Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType); + return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel); + } else { + // select X, C, (sext X) --> select X, C, 0 + // select X, C, (zext X) --> select X, C, 0 + Constant *Zero = ConstantInt::getNullValue(SelType); + return SelectInst::Create(Cond, C, Zero, "", nullptr, &Sel); + } + } + + return nullptr; +} + +/// Try to transform a vector select with a constant condition vector into a +/// shuffle for easier combining with other shuffles and insert/extract. +static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { + Value *CondVal = SI.getCondition(); + Constant *CondC; + if (!CondVal->getType()->isVectorTy() || !match(CondVal, m_Constant(CondC))) + return nullptr; + + unsigned NumElts = CondVal->getType()->getVectorNumElements(); + SmallVector<Constant *, 16> Mask; + Mask.reserve(NumElts); + Type *Int32Ty = Type::getInt32Ty(CondVal->getContext()); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = CondC->getAggregateElement(i); + if (!Elt) + return nullptr; + + if (Elt->isOneValue()) { + // If the select condition element is true, choose from the 1st vector. + Mask.push_back(ConstantInt::get(Int32Ty, i)); + } else if (Elt->isNullValue()) { + // If the select condition element is false, choose from the 2nd vector. + Mask.push_back(ConstantInt::get(Int32Ty, i + NumElts)); + } else if (isa<UndefValue>(Elt)) { + // If the select condition element is undef, the shuffle mask is undef. + Mask.push_back(UndefValue::get(Int32Ty)); + } else { + // Bail out on a constant expression. + return nullptr; + } + } + + return new ShuffleVectorInst(SI.getTrueValue(), SI.getFalseValue(), + ConstantVector::get(Mask)); +} + +/// Reuse bitcasted operands between a compare and select: +/// select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> +/// bitcast (select (cmp (bitcast C), (bitcast D)), (bitcast C), (bitcast D)) +static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + + CmpInst::Predicate Pred; + Value *A, *B; + if (!match(Cond, m_Cmp(Pred, m_Value(A), m_Value(B)))) + return nullptr; + + // The select condition is a compare instruction. If the select's true/false + // values are already the same as the compare operands, there's nothing to do. 
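canonicalizeSelectToShuffle turns a vector select with a constant condition into a shufflevector whose lane i reads element i of the first operand when the condition lane is true and element i + NumElts otherwise. A plain C++ simulation of that mask construction (undef lanes omitted; the containers are only stand-ins for IR vectors):

#include <cassert>
#include <vector>

std::vector<unsigned> selectMask(const std::vector<bool> &Cond) {
  const unsigned NumElts = static_cast<unsigned>(Cond.size());
  std::vector<unsigned> Mask(NumElts);
  for (unsigned i = 0; i != NumElts; ++i)
    Mask[i] = Cond[i] ? i : i + NumElts; // pick from 1st or 2nd source vector
  return Mask;
}

int main() {
  std::vector<bool> Cond = {true, false, false, true};
  std::vector<int> T = {10, 11, 12, 13}, F = {20, 21, 22, 23};
  std::vector<unsigned> Mask = selectMask(Cond);
  for (unsigned i = 0; i != Mask.size(); ++i) {
    int Shuffled = Mask[i] < T.size() ? T[Mask[i]] : F[Mask[i] - T.size()];
    assert(Shuffled == (Cond[i] ? T[i] : F[i])); // shuffle == vector select
  }
  return 0;
}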
+ if (TVal == A || TVal == B || FVal == A || FVal == B) + return nullptr; + + Value *C, *D; + if (!match(A, m_BitCast(m_Value(C))) || !match(B, m_BitCast(m_Value(D)))) + return nullptr; + + // select (cmp (bitcast C), (bitcast D)), (bitcast TSrc), (bitcast FSrc) + Value *TSrc, *FSrc; + if (!match(TVal, m_BitCast(m_Value(TSrc))) || + !match(FVal, m_BitCast(m_Value(FSrc)))) + return nullptr; + + // If the select true/false values are *different bitcasts* of the same source + // operands, make the select operands the same as the compare operands and + // cast the result. This is the canonical select form for min/max. + Value *NewSel; + if (TSrc == C && FSrc == D) { + // select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) --> + // bitcast (select (cmp A, B), A, B) + NewSel = Builder.CreateSelect(Cond, A, B, "", &Sel); + } else if (TSrc == D && FSrc == C) { + // select (cmp (bitcast C), (bitcast D)), (bitcast' D), (bitcast' C) --> + // bitcast (select (cmp A, B), B, A) + NewSel = Builder.CreateSelect(Cond, B, A, "", &Sel); + } else { + return nullptr; + } + return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType()); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -919,9 +1101,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Type *SelType = SI.getType(); if (Value *V = - SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, TLI, DT, AC)) + SimplifySelectInst(CondVal, TrueVal, FalseVal, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(SI, V); + if (Instruction *I = canonicalizeSelectToShuffle(SI)) + return I; + if (SelType->getScalarType()->isIntegerTy(1) && TrueVal->getType() == CondVal->getType()) { if (match(TrueVal, m_One())) { @@ -1085,7 +1270,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // See if we are selecting two values based on a comparison of the two values. if (ICmpInst *ICI = dyn_cast<ICmpInst>(CondVal)) - if (Instruction *Result = visitSelectInstWithICmp(SI, ICI)) + if (Instruction *Result = foldSelectInstWithICmp(SI, ICI)) return Result; if (Instruction *Add = foldAddSubSelect(SI, *Builder)) @@ -1095,12 +1280,15 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { auto *TI = dyn_cast<Instruction>(TrueVal); auto *FI = dyn_cast<Instruction>(FalseVal); if (TI && FI && TI->getOpcode() == FI->getOpcode()) - if (Instruction *IV = FoldSelectOpOp(SI, TI, FI)) + if (Instruction *IV = foldSelectOpOp(SI, TI, FI)) return IV; + if (Instruction *I = foldSelectExtConst(SI)) + return I; + // See if we can fold the select into one of our operands. 
if (SelType->isIntOrIntVectorTy() || SelType->isFPOrFPVectorTy()) { - if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal)) + if (Instruction *FoldI = foldSelectIntoOp(SI, TrueVal, FalseVal)) return FoldI; Value *LHS, *RHS, *LHS2, *RHS2; @@ -1124,9 +1312,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Cmp = Builder->CreateFCmp(Pred, LHS, RHS); } - Value *NewSI = Builder->CreateCast(CastOp, - Builder->CreateSelect(Cmp, LHS, RHS), - SelType); + Value *NewSI = Builder->CreateCast( + CastOp, Builder->CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI), + SelType); return replaceInstUsesWith(SI, NewSI); } } @@ -1139,39 +1327,35 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // ABS(ABS(a)) -> ABS(a) // NABS(NABS(a)) -> NABS(a) if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor) - if (Instruction *R = FoldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2, + if (Instruction *R = foldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2, SI, SPF, RHS)) return R; if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor) - if (Instruction *R = FoldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2, + if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2, SI, SPF, LHS)) return R; } // MAX(~a, ~b) -> ~MIN(a, b) - if (SPF == SPF_SMAX || SPF == SPF_UMAX) { - if (IsFreeToInvert(LHS, LHS->hasNUses(2)) && - IsFreeToInvert(RHS, RHS->hasNUses(2))) { - - // This transform adds a xor operation and that extra cost needs to be - // justified. We look for simplifications that will result from - // applying this rule: - - bool Profitable = - (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) || - (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) || - (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); - - if (Profitable) { - Value *NewLHS = Builder->CreateNot(LHS); - Value *NewRHS = Builder->CreateNot(RHS); - Value *NewCmp = SPF == SPF_SMAX - ? Builder->CreateICmpSLT(NewLHS, NewRHS) - : Builder->CreateICmpULT(NewLHS, NewRHS); - Value *NewSI = - Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); - return replaceInstUsesWith(SI, NewSI); - } + if ((SPF == SPF_SMAX || SPF == SPF_UMAX) && + IsFreeToInvert(LHS, LHS->hasNUses(2)) && + IsFreeToInvert(RHS, RHS->hasNUses(2))) { + // For this transform to be profitable, we need to eliminate at least two + // 'not' instructions if we're going to add one 'not' instruction. + int NumberOfNots = + (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) + + (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) + + (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); + + if (NumberOfNots >= 2) { + Value *NewLHS = Builder->CreateNot(LHS); + Value *NewRHS = Builder->CreateNot(RHS); + Value *NewCmp = SPF == SPF_SMAX + ? Builder->CreateICmpSLT(NewLHS, NewRHS) + : Builder->CreateICmpULT(NewLHS, NewRHS); + Value *NewSI = + Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); + return replaceInstUsesWith(SI, NewSI); } } @@ -1182,8 +1366,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // See if we can fold the select into a phi node if the condition is a select. if (isa<PHINode>(SI.getCondition())) // The true/false values have to be live in the PHI predecessor's blocks. 
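The MAX(~a, ~b) -> ~MIN(a, b) rewrite above is only applied when enough 'not' instructions disappear, but the identity itself always holds in two's complement, for both signed and unsigned orderings:

#include <algorithm>
#include <cassert>

int main() {
  int a = 17, b = -4;
  // ~x == -x - 1 reverses the signed ordering, so max flips to min under ~.
  assert(std::max(~a, ~b) == ~std::min(a, b));

  unsigned ua = 0x12345678u, ub = 0xCAFEBABEu;
  // ~x == UINT_MAX - x reverses the unsigned ordering as well.
  assert(std::max(~ua, ~ub) == ~std::min(ua, ub));
  return 0;
}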
- if (CanSelectOperandBeMappingIntoPredBlock(TrueVal, SI) && - CanSelectOperandBeMappingIntoPredBlock(FalseVal, SI)) + if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) && + canSelectOperandBeMappingIntoPredBlock(FalseVal, SI)) if (Instruction *NV = FoldOpIntoPhi(SI)) return NV; @@ -1233,7 +1417,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return &SI; } - if (VectorType* VecTy = dyn_cast<VectorType>(SelType)) { + if (VectorType *VecTy = dyn_cast<VectorType>(SelType)) { unsigned VWidth = VecTy->getNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); @@ -1266,5 +1450,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, *Builder)) + return BitCastSel; + return nullptr; } diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 08e16a7..4ff9b64 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -39,10 +39,19 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I)) return Res; + // (C1 shift (A add C2)) -> (C1 shift C2) shift A) + // iff A and C2 are both positive. + Value *A; + Constant *C; + if (match(Op0, m_Constant()) && match(Op1, m_Add(m_Value(A), m_Constant(C)))) + if (isKnownNonNegative(A, DL) && isKnownNonNegative(C, DL)) + return BinaryOperator::Create( + I.getOpcode(), Builder->CreateBinOp(I.getOpcode(), Op0, C), A); + // X shift (A srem B) -> X shift (A and B-1) iff B is a power of 2. // Because shifts by negative values (which could occur if A were negative) // are undefined. - Value *A; const APInt *B; + const APInt *B; if (Op1->hasOneUse() && match(Op1, m_SRem(m_Value(A), m_Power2(B)))) { // FIXME: Should this get moved into SimplifyDemandedBits by saying we don't // demand the sign bit (and many others) here?? @@ -194,8 +203,10 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, else V = IC.Builder->CreateLShr(C, NumBits); // If we got a constantexpr back, try to simplify it with TD info. - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) - V = ConstantFoldConstantExpression(CE, DL, IC.getTargetLibraryInfo()); + if (auto *C = dyn_cast<Constant>(V)) + if (auto *FoldedC = + ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo())) + V = FoldedC; return V; } @@ -317,7 +328,167 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, } } +/// Try to fold (X << C1) << C2, where the shifts are some combination of +/// shl/ashr/lshr. +static Instruction * +foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1, + InstCombiner::BuilderTy *Builder) { + Value *Op0 = I.getOperand(0); + uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); + + // Find out if this is a shift of a shift by a constant. + BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0); + if (ShiftOp && !ShiftOp->isShift()) + ShiftOp = nullptr; + + if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) { + + // This is a constant shift of a constant shift. Be careful about hiding + // shl instructions behind bit masks. They are used to represent multiplies + // by a constant, and it is important that simple arithmetic expressions + // are still recognizable by scalar evolution. 
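The new (C1 shift (A add C2)) -> ((C1 shift C2) shift A) fold in commonShiftTransforms requires A and C2 to be non-negative so the regrouped shift amounts stay in range. A standalone check with in-range amounts (the constants are arbitrary):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C1 = 0x3;
  for (uint32_t A = 0; A <= 8; ++A) {
    const uint32_t C2 = 5;               // A and C2 non-negative, A + C2 < 32
    assert((C1 << (A + C2)) == ((C1 << C2) << A));
    // The same regrouping holds for right shifts.
    const uint32_t X = 0xF00DFACEu;
    assert((X >> (A + C2)) == ((X >> C2) >> A));
  }
  return 0;
}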
+ // + // The transforms applied to shl are very similar to the transforms applied + // to mul by constant. We can be more aggressive about optimizing right + // shifts. + // + // Combinations of right and left shifts will still be optimized in + // DAGCombine where scalar evolution no longer applies. + + ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); + uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); + uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); + assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); + if (ShiftAmt1 == 0) + return nullptr; // Will be simplified in the future. + Value *X = ShiftOp->getOperand(0); + + IntegerType *Ty = cast<IntegerType>(I.getType()); + + // Check for (X << c1) << c2 and (X >> c1) >> c2 + if (I.getOpcode() == ShiftOp->getOpcode()) { + uint32_t AmtSum = ShiftAmt1 + ShiftAmt2; // Fold into one big shift. + // If this is an oversized composite shift, then unsigned shifts become + // zero (handled in InstSimplify) and ashr saturates. + if (AmtSum >= TypeBits) { + if (I.getOpcode() != Instruction::AShr) + return nullptr; + AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr. + } + + return BinaryOperator::Create(I.getOpcode(), X, + ConstantInt::get(Ty, AmtSum)); + } + + if (ShiftAmt1 == ShiftAmt2) { + // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). + if (I.getOpcode() == Instruction::LShr && + ShiftOp->getOpcode() == Instruction::Shl) { + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); + return BinaryOperator::CreateAnd( + X, ConstantInt::get(I.getContext(), Mask)); + } + } else if (ShiftAmt1 < ShiftAmt2) { + uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1; + + // (X >>?,exact C1) << C2 --> X << (C2-C1) + // The inexact version is deferred to DAGCombine so we don't hide shl + // behind a bit mask. + if (I.getOpcode() == Instruction::Shl && + ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { + assert(ShiftOp->getOpcode() == Instruction::LShr || + ShiftOp->getOpcode() == Instruction::AShr); + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + BinaryOperator *NewShl = + BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); + NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); + return NewShl; + } + // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2) + if (I.getOpcode() == Instruction::LShr && + ShiftOp->getOpcode() == Instruction::Shl) { + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + // (X <<nuw C1) >>u C2 --> X >>u (C2-C1) + if (ShiftOp->hasNoUnsignedWrap()) { + BinaryOperator *NewLShr = + BinaryOperator::Create(Instruction::LShr, X, ShiftDiffCst); + NewLShr->setIsExact(I.isExact()); + return NewLShr; + } + Value *Shift = Builder->CreateLShr(X, ShiftDiffCst); + + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::CreateAnd( + Shift, ConstantInt::get(I.getContext(), Mask)); + } + + // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, + // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. 
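foldShiftByConstOfShiftByConst covers several shift-of-shift shapes; the three most common identities it relies on, checked with arbitrary in-range values (plain C++, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t X = 0x0000BEEFu;

  // (X << C1) << C2  -->  X << (C1 + C2)   (sum stays below the bit width)
  assert(((X << 3) << 7) == (X << 10));

  // ((X << C) >>u C)  -->  X & (-1 >>u C)  (mask off the bits that were lost)
  const uint32_t C = 12;
  assert(((X << C) >> C) == (X & (~0u >> C)));

  // (X <<nuw C1) >>u C2  -->  X >>u (C2 - C1) when C1 < C2; here X fits in
  // 16 bits, so shifting left by 8 cannot lose any bits.
  assert(((X << 8) >> 14) == (X >> 6));
  return 0;
}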
+ if (I.getOpcode() == Instruction::AShr && + ShiftOp->getOpcode() == Instruction::Shl) { + if (ShiftOp->hasNoSignedWrap()) { + // (X <<nsw C1) >>s C2 --> X >>s (C2-C1) + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + BinaryOperator *NewAShr = + BinaryOperator::Create(Instruction::AShr, X, ShiftDiffCst); + NewAShr->setIsExact(I.isExact()); + return NewAShr; + } + } + } else { + assert(ShiftAmt2 < ShiftAmt1); + uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2; + + // (X >>?exact C1) << C2 --> X >>?exact (C1-C2) + // The inexact version is deferred to DAGCombine so we don't hide shl + // behind a bit mask. + if (I.getOpcode() == Instruction::Shl && + ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + BinaryOperator *NewShr = + BinaryOperator::Create(ShiftOp->getOpcode(), X, ShiftDiffCst); + NewShr->setIsExact(true); + return NewShr; + } + + // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2) + if (I.getOpcode() == Instruction::LShr && + ShiftOp->getOpcode() == Instruction::Shl) { + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + if (ShiftOp->hasNoUnsignedWrap()) { + // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2) + BinaryOperator *NewShl = + BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); + NewShl->setHasNoUnsignedWrap(true); + return NewShl; + } + Value *Shift = Builder->CreateShl(X, ShiftDiffCst); + + APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); + return BinaryOperator::CreateAnd( + Shift, ConstantInt::get(I.getContext(), Mask)); + } + + // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, + // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. + if (I.getOpcode() == Instruction::AShr && + ShiftOp->getOpcode() == Instruction::Shl) { + if (ShiftOp->hasNoSignedWrap()) { + // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2) + ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); + BinaryOperator *NewShl = + BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); + NewShl->setHasNoSignedWrap(true); + return NewShl; + } + } + } + } + + return nullptr; +} Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { @@ -359,13 +530,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, return BinaryOperator::CreateMul(BO->getOperand(0), ConstantExpr::getShl(BOOp, Op1)); - // Try to fold constant and into select arguments. - if (SelectInst *SI = dyn_cast<SelectInst>(Op0)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; - if (isa<PHINode>(Op0)) - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I)) + return FoldedShift; // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) if (TruncInst *TI = dyn_cast<TruncInst>(Op0)) { @@ -455,9 +621,9 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, V1->getName()+".mask"); return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); } + LLVM_FALLTHROUGH; } - // FALL THROUGH. case Instruction::Sub: { // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && @@ -539,157 +705,9 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } } - // Find out if this is a shift of a shift by a constant. 
- BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0); - if (ShiftOp && !ShiftOp->isShift()) - ShiftOp = nullptr; - - if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) { - - // This is a constant shift of a constant shift. Be careful about hiding - // shl instructions behind bit masks. They are used to represent multiplies - // by a constant, and it is important that simple arithmetic expressions - // are still recognizable by scalar evolution. - // - // The transforms applied to shl are very similar to the transforms applied - // to mul by constant. We can be more aggressive about optimizing right - // shifts. - // - // Combinations of right and left shifts will still be optimized in - // DAGCombine where scalar evolution no longer applies. - - ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); - uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); - uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); - assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); - if (ShiftAmt1 == 0) return nullptr; // Will be simplified in the future. - Value *X = ShiftOp->getOperand(0); - - IntegerType *Ty = cast<IntegerType>(I.getType()); - - // Check for (X << c1) << c2 and (X >> c1) >> c2 - if (I.getOpcode() == ShiftOp->getOpcode()) { - uint32_t AmtSum = ShiftAmt1+ShiftAmt2; // Fold into one big shift. - // If this is oversized composite shift, then unsigned shifts get 0, ashr - // saturates. - if (AmtSum >= TypeBits) { - if (I.getOpcode() != Instruction::AShr) - return replaceInstUsesWith(I, Constant::getNullValue(I.getType())); - AmtSum = TypeBits-1; // Saturate to 31 for i32 ashr. - } - - return BinaryOperator::Create(I.getOpcode(), X, - ConstantInt::get(Ty, AmtSum)); - } - - if (ShiftAmt1 == ShiftAmt2) { - // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); - return BinaryOperator::CreateAnd(X, - ConstantInt::get(I.getContext(), Mask)); - } - } else if (ShiftAmt1 < ShiftAmt2) { - uint32_t ShiftDiff = ShiftAmt2-ShiftAmt1; - - // (X >>?,exact C1) << C2 --> X << (C2-C1) - // The inexact version is deferred to DAGCombine so we don't hide shl - // behind a bit mask. - if (I.getOpcode() == Instruction::Shl && - ShiftOp->getOpcode() != Instruction::Shl && - ShiftOp->isExact()) { - assert(ShiftOp->getOpcode() == Instruction::LShr || - ShiftOp->getOpcode() == Instruction::AShr); - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShl = BinaryOperator::Create(Instruction::Shl, - X, ShiftDiffCst); - NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); - return NewShl; - } - - // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2) - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - // (X <<nuw C1) >>u C2 --> X >>u (C2-C1) - if (ShiftOp->hasNoUnsignedWrap()) { - BinaryOperator *NewLShr = BinaryOperator::Create(Instruction::LShr, - X, ShiftDiffCst); - NewLShr->setIsExact(I.isExact()); - return NewLShr; - } - Value *Shift = Builder->CreateLShr(X, ShiftDiffCst); - - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); - return BinaryOperator::CreateAnd(Shift, - ConstantInt::get(I.getContext(),Mask)); - } - - // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. 
However, - // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. - if (I.getOpcode() == Instruction::AShr && - ShiftOp->getOpcode() == Instruction::Shl) { - if (ShiftOp->hasNoSignedWrap()) { - // (X <<nsw C1) >>s C2 --> X >>s (C2-C1) - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewAShr = BinaryOperator::Create(Instruction::AShr, - X, ShiftDiffCst); - NewAShr->setIsExact(I.isExact()); - return NewAShr; - } - } - } else { - assert(ShiftAmt2 < ShiftAmt1); - uint32_t ShiftDiff = ShiftAmt1-ShiftAmt2; - - // (X >>?exact C1) << C2 --> X >>?exact (C1-C2) - // The inexact version is deferred to DAGCombine so we don't hide shl - // behind a bit mask. - if (I.getOpcode() == Instruction::Shl && - ShiftOp->getOpcode() != Instruction::Shl && - ShiftOp->isExact()) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShr = BinaryOperator::Create(ShiftOp->getOpcode(), - X, ShiftDiffCst); - NewShr->setIsExact(true); - return NewShr; - } + if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder)) + return Folded; - // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2) - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - if (ShiftOp->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2) - BinaryOperator *NewShl = BinaryOperator::Create(Instruction::Shl, - X, ShiftDiffCst); - NewShl->setHasNoUnsignedWrap(true); - return NewShl; - } - Value *Shift = Builder->CreateShl(X, ShiftDiffCst); - - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); - return BinaryOperator::CreateAnd(Shift, - ConstantInt::get(I.getContext(),Mask)); - } - - // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, - // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. - if (I.getOpcode() == Instruction::AShr && - ShiftOp->getOpcode() == Instruction::Shl) { - if (ShiftOp->hasNoSignedWrap()) { - // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2) - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShl = BinaryOperator::Create(Instruction::Shl, - X, ShiftDiffCst); - NewShl->setHasNoSignedWrap(true); - return NewShl; - } - } - } - } return nullptr; } @@ -699,7 +717,7 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) + I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) @@ -708,6 +726,25 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (ConstantInt *Op1C = dyn_cast<ConstantInt>(I.getOperand(1))) { unsigned ShAmt = Op1C->getZExtValue(); + // Turn: + // %zext = zext i32 %V to i64 + // %res = shl i64 %V, 8 + // + // Into: + // %shl = shl i32 %V, 8 + // %res = zext i32 %shl to i64 + // + // This is only valid if %V would have zeros shifted out. + if (auto *ZI = dyn_cast<ZExtInst>(I.getOperand(0))) { + unsigned SrcBitWidth = ZI->getSrcTy()->getScalarSizeInBits(); + if (ShAmt < SrcBitWidth && + MaskedValueIsZero(ZI->getOperand(0), + APInt::getHighBitsSet(SrcBitWidth, ShAmt), 0, &I)) { + auto *Shl = Builder->CreateShl(ZI->getOperand(0), ShAmt); + return new ZExtInst(Shl, I.getType()); + } + } + // If the shifted-out value is known-zero, then this is a NUW shift. 
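The new shl-of-zext narrowing in visitShl is valid only when the bits that would be shifted out of the narrow value are already zero (the MaskedValueIsZero check). A standalone check of that condition with an illustrative value:

#include <cassert>
#include <cstdint>

int main() {
  // shl i64 (zext i32 V), 8 can become zext i64 (shl i32 V, 8) provided the
  // top 8 bits of V are already zero, so the narrow shl cannot drop anything.
  uint32_t V = 0x00ABCDEFu;                        // high 8 bits are zero
  uint64_t wide   = static_cast<uint64_t>(V) << 8; // shl on the zext
  uint64_t narrow = static_cast<uint64_t>(V << 8); // zext of the narrow shl
  assert(wide == narrow);
  return 0;
}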
if (!I.hasNoUnsignedWrap() && MaskedValueIsZero(I.getOperand(0), @@ -740,7 +777,7 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) @@ -784,7 +821,7 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), - DL, TLI, DT, AC)) + DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index f3268d2..8b930bd 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -981,6 +981,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, bool MadeChange = false; APInt UndefElts2(VWidth, 0); + APInt UndefElts3(VWidth, 0); Value *TmpV; switch (I->getOpcode()) { default: break; @@ -1020,8 +1021,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } case Instruction::ShuffleVector: { ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I); - uint64_t LHSVWidth = - cast<VectorType>(Shuffle->getOperand(0)->getType())->getNumElements(); + unsigned LHSVWidth = + Shuffle->getOperand(0)->getType()->getVectorNumElements(); APInt LeftDemanded(LHSVWidth, 0), RightDemanded(LHSVWidth, 0); for (unsigned i = 0; i < VWidth; i++) { if (DemandedElts[i]) { @@ -1037,17 +1038,21 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } } - APInt UndefElts4(LHSVWidth, 0); + APInt LHSUndefElts(LHSVWidth, 0); TmpV = SimplifyDemandedVectorElts(I->getOperand(0), LeftDemanded, - UndefElts4, Depth + 1); + LHSUndefElts, Depth + 1); if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } - APInt UndefElts3(LHSVWidth, 0); + APInt RHSUndefElts(LHSVWidth, 0); TmpV = SimplifyDemandedVectorElts(I->getOperand(1), RightDemanded, - UndefElts3, Depth + 1); + RHSUndefElts, Depth + 1); if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; } bool NewUndefElts = false; + unsigned LHSIdx = -1u, LHSValIdx = -1u; + unsigned RHSIdx = -1u, RHSValIdx = -1u; + bool LHSUniform = true; + bool RHSUniform = true; for (unsigned i = 0; i < VWidth; i++) { unsigned MaskVal = Shuffle->getMaskValue(i); if (MaskVal == -1u) { @@ -1056,18 +1061,59 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, NewUndefElts = true; UndefElts.setBit(i); } else if (MaskVal < LHSVWidth) { - if (UndefElts4[MaskVal]) { + if (LHSUndefElts[MaskVal]) { NewUndefElts = true; UndefElts.setBit(i); + } else { + LHSIdx = LHSIdx == -1u ? i : LHSVWidth; + LHSValIdx = LHSValIdx == -1u ? MaskVal : LHSVWidth; + LHSUniform = LHSUniform && (MaskVal == i); } } else { - if (UndefElts3[MaskVal - LHSVWidth]) { + if (RHSUndefElts[MaskVal - LHSVWidth]) { NewUndefElts = true; UndefElts.setBit(i); + } else { + RHSIdx = RHSIdx == -1u ? i : LHSVWidth; + RHSValIdx = RHSValIdx == -1u ? MaskVal - LHSVWidth : LHSVWidth; + RHSUniform = RHSUniform && (MaskVal - LHSVWidth == i); } } } + // Try to transform shuffle with constant vector and single element from + // this constant vector to single insertelement instruction. 
+ // shufflevector V, C, <v1, v2, .., ci, .., vm> -> + // insertelement V, C[ci], ci-n + if (LHSVWidth == Shuffle->getType()->getNumElements()) { + Value *Op = nullptr; + Constant *Value = nullptr; + unsigned Idx = -1u; + + // Find constant vector with the single element in shuffle (LHS or RHS). + if (LHSIdx < LHSVWidth && RHSUniform) { + if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) { + Op = Shuffle->getOperand(1); + Value = CV->getOperand(LHSValIdx); + Idx = LHSIdx; + } + } + if (RHSIdx < LHSVWidth && LHSUniform) { + if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) { + Op = Shuffle->getOperand(0); + Value = CV->getOperand(RHSValIdx); + Idx = RHSIdx; + } + } + // Found constant vector with single element - convert to insertelement. + if (Op && Value) { + Instruction *New = InsertElementInst::Create( + Op, Value, ConstantInt::get(Type::getInt32Ty(I->getContext()), Idx), + Shuffle->getName()); + InsertNewInstWith(New, *Shuffle); + return New; + } + } if (NewUndefElts) { // Add additional discovered undefs. SmallVector<Constant*, 16> Elts; @@ -1209,114 +1255,223 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, switch (II->getIntrinsicID()) { default: break; + case Intrinsic::x86_xop_vfrcz_ss: + case Intrinsic::x86_xop_vfrcz_sd: + // The instructions for these intrinsics are speced to zero upper bits not + // pass them through like other scalar intrinsics. So we shouldn't just + // use Arg0 if DemandedElts[0] is clear like we do for other intrinsics. + // Instead we should return a zero vector. + if (!DemandedElts[0]) { + Worklist.Add(II); + return ConstantAggregateZero::get(II->getType()); + } + + // Only the lower element is used. + DemandedElts = 1; + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + // Only the lower element is undefined. The high elements are zero. + UndefElts = UndefElts[0]; + break; + // Unary scalar-as-vector operations that work column-wise. case Intrinsic::x86_sse_rcp_ss: case Intrinsic::x86_sse_rsqrt_ss: case Intrinsic::x86_sse_sqrt_ss: case Intrinsic::x86_sse2_sqrt_sd: - case Intrinsic::x86_xop_vfrcz_ss: - case Intrinsic::x86_xop_vfrcz_sd: TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } // If lowest element of a scalar op isn't used then use Arg0. - if (DemandedElts.getLoBits(1) != 1) + if (!DemandedElts[0]) { + Worklist.Add(II); return II->getArgOperand(0); + } // TODO: If only low elt lower SQRT to FSQRT (with rounding/exceptions // checks). break; - // Binary scalar-as-vector operations that work column-wise. A dest element - // is a function of the corresponding input elements from the two inputs. - case Intrinsic::x86_sse_add_ss: - case Intrinsic::x86_sse_sub_ss: - case Intrinsic::x86_sse_mul_ss: - case Intrinsic::x86_sse_div_ss: + // Binary scalar-as-vector operations that work column-wise. The high + // elements come from operand 0. The low element is a function of both + // operands. 
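// A minimal sketch (illustration only, not part of the patch) of the
// shuffle-with-constant-operand to insertelement rewrite described above,
// modelling a <4 x i32> shufflevector with std::array; helper names are
// made up here.
#include <array>
#include <cassert>
#include <cstddef>

using Vec4 = std::array<int, 4>;

// shufflevector semantics: mask values 0..3 select from L, 4..7 select from R.
static Vec4 shuffle(const Vec4 &L, const Vec4 &R, const std::array<int, 4> &Mask) {
  Vec4 Out{};
  for (std::size_t I = 0; I != 4; ++I)
    Out[I] = Mask[I] < 4 ? L[Mask[I]] : R[Mask[I] - 4];
  return Out;
}

static Vec4 insertElement(Vec4 V, int Val, std::size_t Idx) {
  V[Idx] = Val;
  return V;
}

int main() {
  Vec4 X = {10, 11, 12, 13};
  Vec4 C = {0, 0, 77, 0};                  // constant operand; only C[2] is used
  std::array<int, 4> Mask = {0, 1, 6, 3};  // identity, except lane 2 <- C[2]
  // The whole shuffle collapses to a single insertelement of C[2] at index 2.
  assert(shuffle(X, C, Mask) == insertElement(X, C[2], 2));
  return 0;
}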
case Intrinsic::x86_sse_min_ss: case Intrinsic::x86_sse_max_ss: case Intrinsic::x86_sse_cmp_ss: - case Intrinsic::x86_sse2_add_sd: - case Intrinsic::x86_sse2_sub_sd: - case Intrinsic::x86_sse2_mul_sd: - case Intrinsic::x86_sse2_div_sd: case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: - case Intrinsic::x86_sse2_cmp_sd: - case Intrinsic::x86_sse41_round_ss: - case Intrinsic::x86_sse41_round_sd: + case Intrinsic::x86_sse2_cmp_sd: { TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, UndefElts, Depth + 1); if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + // If lowest element of a scalar op isn't used then use Arg0. + if (!DemandedElts[0]) { + Worklist.Add(II); + return II->getArgOperand(0); + } + + // Only lower element is used for operand 1. + DemandedElts = 1; TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, UndefElts2, Depth + 1); if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } - // If only the low elt is demanded and this is a scalarizable intrinsic, - // scalarize it now. - if (DemandedElts == 1) { - switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::x86_sse_add_ss: - case Intrinsic::x86_sse_sub_ss: - case Intrinsic::x86_sse_mul_ss: - case Intrinsic::x86_sse_div_ss: - case Intrinsic::x86_sse2_add_sd: - case Intrinsic::x86_sse2_sub_sd: - case Intrinsic::x86_sse2_mul_sd: - case Intrinsic::x86_sse2_div_sd: - // TODO: Lower MIN/MAX/etc. - Value *LHS = II->getArgOperand(0); - Value *RHS = II->getArgOperand(1); - // Extract the element as scalars. - LHS = InsertNewInstWith(ExtractElementInst::Create(LHS, - ConstantInt::get(Type::getInt32Ty(I->getContext()), 0U)), *II); - RHS = InsertNewInstWith(ExtractElementInst::Create(RHS, - ConstantInt::get(Type::getInt32Ty(I->getContext()), 0U)), *II); - - switch (II->getIntrinsicID()) { - default: llvm_unreachable("Case stmts out of sync!"); - case Intrinsic::x86_sse_add_ss: - case Intrinsic::x86_sse2_add_sd: - TmpV = InsertNewInstWith(BinaryOperator::CreateFAdd(LHS, RHS, - II->getName()), *II); - break; - case Intrinsic::x86_sse_sub_ss: - case Intrinsic::x86_sse2_sub_sd: - TmpV = InsertNewInstWith(BinaryOperator::CreateFSub(LHS, RHS, - II->getName()), *II); - break; - case Intrinsic::x86_sse_mul_ss: - case Intrinsic::x86_sse2_mul_sd: - TmpV = InsertNewInstWith(BinaryOperator::CreateFMul(LHS, RHS, - II->getName()), *II); - break; - case Intrinsic::x86_sse_div_ss: - case Intrinsic::x86_sse2_div_sd: - TmpV = InsertNewInstWith(BinaryOperator::CreateFDiv(LHS, RHS, - II->getName()), *II); - break; - } - - Instruction *New = - InsertElementInst::Create( - UndefValue::get(II->getType()), TmpV, - ConstantInt::get(Type::getInt32Ty(I->getContext()), 0U, false), - II->getName()); - InsertNewInstWith(New, *II); - return New; - } + // Lower element is undefined if both lower elements are undefined. + // Consider things like undef&0. The result is known zero, not undef. + if (!UndefElts2[0]) + UndefElts.clearBit(0); + + break; + } + + // Binary scalar-as-vector operations that work column-wise. The high + // elements come from operand 0 and the low element comes from operand 1. + case Intrinsic::x86_sse41_round_ss: + case Intrinsic::x86_sse41_round_sd: { + // Don't use the low element of operand 0. 
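// A sketch (not part of the patch) of the lane flow assumed by the
// min/max/cmp *_ss and *_sd handling above: lane 0 of the result is a
// function of both operands, the upper lanes are passed through from
// operand 0, so only element 0 of operand 1 is ever demanded.
#include <algorithm>
#include <array>
#include <cassert>

using Vec4f = std::array<float, 4>;

static Vec4f minSS(Vec4f A, const Vec4f &B) {
  A[0] = std::min(A[0], B[0]);   // only B[0] can influence the result
  return A;
}

int main() {
  Vec4f A  = {4.0f, 1.0f, 2.0f, 3.0f};
  Vec4f B1 = {0.5f, 9.0f, 9.0f, 9.0f};
  Vec4f B2 = {0.5f, -1.0f, -2.0f, -3.0f};  // differs from B1 only in lanes 1..3
  assert(minSS(A, B1) == minSS(A, B2));    // the upper lanes of B never matter
  return 0;
}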
+ APInt DemandedElts2 = DemandedElts; + DemandedElts2.clearBit(0); + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts2, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + // If lowest element of a scalar op isn't used then use Arg0. + if (!DemandedElts[0]) { + Worklist.Add(II); + return II->getArgOperand(0); } + // Only lower element is used for operand 1. + DemandedElts = 1; + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, + UndefElts2, Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + + // Take the high undef elements from operand 0 and take the lower element + // from operand 1. + UndefElts.clearBit(0); + UndefElts |= UndefElts2[0]; + break; + } + + // Three input scalar-as-vector operations that work column-wise. The high + // elements come from operand 0 and the low element is a function of all + // three inputs. + case Intrinsic::x86_avx512_mask_add_ss_round: + case Intrinsic::x86_avx512_mask_div_ss_round: + case Intrinsic::x86_avx512_mask_mul_ss_round: + case Intrinsic::x86_avx512_mask_sub_ss_round: + case Intrinsic::x86_avx512_mask_max_ss_round: + case Intrinsic::x86_avx512_mask_min_ss_round: + case Intrinsic::x86_avx512_mask_add_sd_round: + case Intrinsic::x86_avx512_mask_div_sd_round: + case Intrinsic::x86_avx512_mask_mul_sd_round: + case Intrinsic::x86_avx512_mask_sub_sd_round: + case Intrinsic::x86_avx512_mask_max_sd_round: + case Intrinsic::x86_avx512_mask_min_sd_round: + case Intrinsic::x86_fma_vfmadd_ss: + case Intrinsic::x86_fma_vfmsub_ss: + case Intrinsic::x86_fma_vfnmadd_ss: + case Intrinsic::x86_fma_vfnmsub_ss: + case Intrinsic::x86_fma_vfmadd_sd: + case Intrinsic::x86_fma_vfmsub_sd: + case Intrinsic::x86_fma_vfnmadd_sd: + case Intrinsic::x86_fma_vfnmsub_sd: + case Intrinsic::x86_avx512_mask_vfmadd_ss: + case Intrinsic::x86_avx512_mask_vfmadd_sd: + case Intrinsic::x86_avx512_maskz_vfmadd_ss: + case Intrinsic::x86_avx512_maskz_vfmadd_sd: + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + // If lowest element of a scalar op isn't used then use Arg0. - if (DemandedElts.getLoBits(1) != 1) + if (!DemandedElts[0]) { + Worklist.Add(II); return II->getArgOperand(0); + } + + // Only lower element is used for operand 1 and 2. + DemandedElts = 1; + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, + UndefElts2, Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(2), DemandedElts, + UndefElts3, Depth + 1); + if (TmpV) { II->setArgOperand(2, TmpV); MadeChange = true; } + + // Lower element is undefined if all three lower elements are undefined. + // Consider things like undef&0. The result is known zero, not undef. + if (!UndefElts2[0] || !UndefElts3[0]) + UndefElts.clearBit(0); - // Output elements are undefined if both are undefined. Consider things - // like undef&0. The result is known zero, not undef. - UndefElts &= UndefElts2; break; + case Intrinsic::x86_avx512_mask3_vfmadd_ss: + case Intrinsic::x86_avx512_mask3_vfmadd_sd: + case Intrinsic::x86_avx512_mask3_vfmsub_ss: + case Intrinsic::x86_avx512_mask3_vfmsub_sd: + case Intrinsic::x86_avx512_mask3_vfnmsub_ss: + case Intrinsic::x86_avx512_mask3_vfnmsub_sd: + // These intrinsics get the passthru bits from operand 2. 
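// A sketch (illustration only) of the scalar three-operand lane flow handled
// above: lane 0 combines lane 0 of all three operands, the upper lanes are
// passed through from the operand carrying the passthru bits (operand 0 for
// the vfmadd_ss form modelled here; the mask3_* variants take them from
// operand 2 instead).
#include <array>
#include <cassert>

using Vec4f = std::array<float, 4>;

static Vec4f fmaddSS(const Vec4f &A, const Vec4f &B, const Vec4f &C) {
  Vec4f R = A;                   // upper lanes pass through from operand 0
  R[0] = A[0] * B[0] + C[0];     // lane 0 uses all three operands
  return R;
}

int main() {
  Vec4f A  = {2.0f, 5.0f, 6.0f, 7.0f};
  Vec4f B1 = {3.0f, 9.0f, 9.0f, 9.0f};
  Vec4f B2 = {3.0f, -9.0f, -9.0f, -9.0f};  // upper lanes differ
  Vec4f C1 = {4.0f, 8.0f, 8.0f, 8.0f};
  Vec4f C2 = {4.0f, -8.0f, -8.0f, -8.0f};  // upper lanes differ
  // Only lane 0 of operands 1 and 2 is demanded; the result is unchanged.
  assert(fmaddSS(A, B1, C1) == fmaddSS(A, B2, C2));
  return 0;
}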
+ TmpV = SimplifyDemandedVectorElts(II->getArgOperand(2), DemandedElts, + UndefElts, Depth + 1); + if (TmpV) { II->setArgOperand(2, TmpV); MadeChange = true; } + + // If lowest element of a scalar op isn't used then use Arg2. + if (!DemandedElts[0]) { + Worklist.Add(II); + return II->getArgOperand(2); + } + + // Only lower element is used for operand 0 and 1. + DemandedElts = 1; + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, + UndefElts2, Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, + UndefElts3, Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + + // Lower element is undefined if all three lower elements are undefined. + // Consider things like undef&0. The result is known zero, not undef. + if (!UndefElts2[0] || !UndefElts3[0]) + UndefElts.clearBit(0); + + break; + + case Intrinsic::x86_sse2_pmulu_dq: + case Intrinsic::x86_sse41_pmuldq: + case Intrinsic::x86_avx2_pmul_dq: + case Intrinsic::x86_avx2_pmulu_dq: + case Intrinsic::x86_avx512_pmul_dq_512: + case Intrinsic::x86_avx512_pmulu_dq_512: { + Value *Op0 = II->getArgOperand(0); + Value *Op1 = II->getArgOperand(1); + unsigned InnerVWidth = Op0->getType()->getVectorNumElements(); + assert((VWidth * 2) == InnerVWidth && "Unexpected input size"); + + APInt InnerDemandedElts(InnerVWidth, 0); + for (unsigned i = 0; i != VWidth; ++i) + if (DemandedElts[i]) + InnerDemandedElts.setBit(i * 2); + + UndefElts2 = APInt(InnerVWidth, 0); + TmpV = SimplifyDemandedVectorElts(Op0, InnerDemandedElts, UndefElts2, + Depth + 1); + if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + + UndefElts3 = APInt(InnerVWidth, 0); + TmpV = SimplifyDemandedVectorElts(Op1, InnerDemandedElts, UndefElts3, + Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + + break; + } + // SSE4A instructions leave the upper 64-bits of the 128-bit result // in an undefined state. case Intrinsic::x86_sse4a_extrq: diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index a761387..b2477f6 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -145,7 +145,7 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { if (Value *V = SimplifyExtractElementInst( - EI.getVectorOperand(), EI.getIndexOperand(), DL, TLI, DT, AC)) + EI.getVectorOperand(), EI.getIndexOperand(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(EI, V); // If vector val is constant with all elements the same, replace EI with @@ -413,6 +413,14 @@ static void replaceExtractElements(InsertElementInst *InsElt, if (InsertionBlock != InsElt->getParent()) return; + // TODO: This restriction matches the check in visitInsertElementInst() and + // prevents an infinite loop caused by not turning the extract/insert pair + // into a shuffle. We really should not need either check, but we're lacking + // folds for shufflevectors because we're afraid to generate shuffle masks + // that the backend can't handle. 
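// A sketch (not part of the patch) of the pmuludq/pmuldq lane mapping used by
// the demanded-elements code above: each 64-bit result element i is built
// from the even 32-bit input lane 2*i, so the odd input lanes are never
// demanded.
#include <array>
#include <cassert>
#include <cstdint>

static std::array<uint64_t, 2> pmuludq(const std::array<uint32_t, 4> &A,
                                       const std::array<uint32_t, 4> &B) {
  std::array<uint64_t, 2> R{};
  R[0] = static_cast<uint64_t>(A[0]) * B[0];  // result lane 0 <- input lane 0
  R[1] = static_cast<uint64_t>(A[2]) * B[2];  // result lane 1 <- input lane 2
  return R;
}

int main() {
  std::array<uint32_t, 4> A  = {3, 999, 5, 999};
  std::array<uint32_t, 4> B1 = {7, 111, 9, 111};
  std::array<uint32_t, 4> B2 = {7, 222, 9, 222};  // only the odd lanes differ
  assert(pmuludq(A, B1) == pmuludq(A, B2));       // odd lanes are not demanded
  return 0;
}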
+ if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back())) + return; + auto *WideVec = new ShuffleVectorInst(ExtVecOp, UndefValue::get(ExtVecType), ConstantVector::get(ExtendMask)); @@ -452,7 +460,7 @@ static ShuffleOps collectShuffleElements(Value *V, Value *PermittedRHS, InstCombiner &IC) { assert(V->getType()->isVectorTy() && "Invalid shuffle!"); - unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); + unsigned NumElts = V->getType()->getVectorNumElements(); if (isa<UndefValue>(V)) { Mask.assign(NumElts, UndefValue::get(Type::getInt32Ty(V->getContext()))); @@ -566,6 +574,176 @@ Instruction *InstCombiner::visitInsertValueInst(InsertValueInst &I) { return nullptr; } +static bool isShuffleEquivalentToSelect(ShuffleVectorInst &Shuf) { + int MaskSize = Shuf.getMask()->getType()->getVectorNumElements(); + int VecSize = Shuf.getOperand(0)->getType()->getVectorNumElements(); + + // A vector select does not change the size of the operands. + if (MaskSize != VecSize) + return false; + + // Each mask element must be undefined or choose a vector element from one of + // the source operands without crossing vector lanes. + for (int i = 0; i != MaskSize; ++i) { + int Elt = Shuf.getMaskValue(i); + if (Elt != -1 && Elt != i && Elt != i + VecSize) + return false; + } + + return true; +} + +// Turn a chain of inserts that splats a value into a canonical insert + shuffle +// splat. That is: +// insertelt(insertelt(insertelt(insertelt X, %k, 0), %k, 1), %k, 2) ... -> +// shufflevector(insertelt(X, %k, 0), undef, zero) +static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { + // We are interested in the last insert in a chain. So, if this insert + // has a single user, and that user is an insert, bail. + if (InsElt.hasOneUse() && isa<InsertElementInst>(InsElt.user_back())) + return nullptr; + + VectorType *VT = cast<VectorType>(InsElt.getType()); + int NumElements = VT->getNumElements(); + + // Do not try to do this for a one-element vector, since that's a nop, + // and will cause an inf-loop. + if (NumElements == 1) + return nullptr; + + Value *SplatVal = InsElt.getOperand(1); + InsertElementInst *CurrIE = &InsElt; + SmallVector<bool, 16> ElementPresent(NumElements, false); + + // Walk the chain backwards, keeping track of which indices we inserted into, + // until we hit something that isn't an insert of the splatted value. + while (CurrIE) { + ConstantInt *Idx = dyn_cast<ConstantInt>(CurrIE->getOperand(2)); + if (!Idx || CurrIE->getOperand(1) != SplatVal) + return nullptr; + + // Check none of the intermediate steps have any additional uses. + if ((CurrIE != &InsElt) && !CurrIE->hasOneUse()) + return nullptr; + + ElementPresent[Idx->getZExtValue()] = true; + CurrIE = dyn_cast<InsertElementInst>(CurrIE->getOperand(0)); + } + + // Make sure we've seen an insert into every element. + if (llvm::any_of(ElementPresent, [](bool Present) { return !Present; })) + return nullptr; + + // All right, create the insert + shuffle. 
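// A sketch (illustration only) of the splat canonicalization performed by
// foldInsSequenceIntoBroadcast above: a chain of insertelements that writes
// the same scalar into every lane is equivalent to one insert into lane 0
// followed by a zero-mask shuffle (a broadcast).
#include <array>
#include <cassert>
#include <cstddef>

using Vec4 = std::array<int, 4>;

static Vec4 insertChainSplat(Vec4 V, int K) {
  for (std::size_t I = 0; I != 4; ++I)   // insertelt(...insertelt(V, K, 0)..., K, 3)
    V[I] = K;
  return V;
}

static Vec4 broadcast(Vec4 V, int K) {
  V[0] = K;                              // insertelement V, K, 0
  Vec4 Out{};
  for (std::size_t I = 0; I != 4; ++I)   // shufflevector with an all-zero mask
    Out[I] = V[0];
  return Out;
}

int main() {
  Vec4 X = {1, 2, 3, 4};
  assert(insertChainSplat(X, 42) == broadcast(X, 42));
  return 0;
}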
+ Instruction *InsertFirst = InsertElementInst::Create( + UndefValue::get(VT), SplatVal, + ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0), "", &InsElt); + + Constant *ZeroMask = ConstantAggregateZero::get( + VectorType::get(Type::getInt32Ty(InsElt.getContext()), NumElements)); + + return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask); +} + +/// insertelt (shufflevector X, CVec, Mask|insertelt X, C1, CIndex1), C, CIndex +/// --> shufflevector X, CVec', Mask' +static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { + auto *Inst = dyn_cast<Instruction>(InsElt.getOperand(0)); + // Bail out if the parent has more than one use. In that case, we'd be + // replacing the insertelt with a shuffle, and that's not a clear win. + if (!Inst || !Inst->hasOneUse()) + return nullptr; + if (auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0))) { + // The shuffle must have a constant vector operand. The insertelt must have + // a constant scalar being inserted at a constant position in the vector. + Constant *ShufConstVec, *InsEltScalar; + uint64_t InsEltIndex; + if (!match(Shuf->getOperand(1), m_Constant(ShufConstVec)) || + !match(InsElt.getOperand(1), m_Constant(InsEltScalar)) || + !match(InsElt.getOperand(2), m_ConstantInt(InsEltIndex))) + return nullptr; + + // Adding an element to an arbitrary shuffle could be expensive, but a + // shuffle that selects elements from vectors without crossing lanes is + // assumed cheap. + // If we're just adding a constant into that shuffle, it will still be + // cheap. + if (!isShuffleEquivalentToSelect(*Shuf)) + return nullptr; + + // From the above 'select' check, we know that the mask has the same number + // of elements as the vector input operands. We also know that each constant + // input element is used in its lane and can not be used more than once by + // the shuffle. Therefore, replace the constant in the shuffle's constant + // vector with the insertelt constant. Replace the constant in the shuffle's + // mask vector with the insertelt index plus the length of the vector + // (because the constant vector operand of a shuffle is always the 2nd + // operand). + Constant *Mask = Shuf->getMask(); + unsigned NumElts = Mask->getType()->getVectorNumElements(); + SmallVector<Constant *, 16> NewShufElts(NumElts); + SmallVector<Constant *, 16> NewMaskElts(NumElts); + for (unsigned I = 0; I != NumElts; ++I) { + if (I == InsEltIndex) { + NewShufElts[I] = InsEltScalar; + Type *Int32Ty = Type::getInt32Ty(Shuf->getContext()); + NewMaskElts[I] = ConstantInt::get(Int32Ty, InsEltIndex + NumElts); + } else { + // Copy over the existing values. + NewShufElts[I] = ShufConstVec->getAggregateElement(I); + NewMaskElts[I] = Mask->getAggregateElement(I); + } + } + + // Create new operands for a shuffle that includes the constant of the + // original insertelt. The old shuffle will be dead now. + return new ShuffleVectorInst(Shuf->getOperand(0), + ConstantVector::get(NewShufElts), + ConstantVector::get(NewMaskElts)); + } else if (auto *IEI = dyn_cast<InsertElementInst>(Inst)) { + // Transform sequences of insertelements ops with constant data/indexes into + // a single shuffle op. 
+ unsigned NumElts = InsElt.getType()->getNumElements(); + + uint64_t InsertIdx[2]; + Constant *Val[2]; + if (!match(InsElt.getOperand(2), m_ConstantInt(InsertIdx[0])) || + !match(InsElt.getOperand(1), m_Constant(Val[0])) || + !match(IEI->getOperand(2), m_ConstantInt(InsertIdx[1])) || + !match(IEI->getOperand(1), m_Constant(Val[1]))) + return nullptr; + SmallVector<Constant *, 16> Values(NumElts); + SmallVector<Constant *, 16> Mask(NumElts); + auto ValI = std::begin(Val); + // Generate new constant vector and mask. + // We have 2 values/masks from the insertelements instructions. Insert them + // into new value/mask vectors. + for (uint64_t I : InsertIdx) { + if (!Values[I]) { + assert(!Mask[I]); + Values[I] = *ValI; + Mask[I] = ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), + NumElts + I); + } + ++ValI; + } + // Remaining values are filled with 'undef' values. + for (unsigned I = 0; I < NumElts; ++I) { + if (!Values[I]) { + assert(!Mask[I]); + Values[I] = UndefValue::get(InsElt.getType()->getElementType()); + Mask[I] = ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), I); + } + } + // Create new operands for a shuffle that includes the constant of the + // original insertelt. + return new ShuffleVectorInst(IEI->getOperand(0), + ConstantVector::get(Values), + ConstantVector::get(Mask)); + } + return nullptr; +} + Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { Value *VecOp = IE.getOperand(0); Value *ScalarOp = IE.getOperand(1); @@ -616,7 +794,7 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { } } - unsigned VWidth = cast<VectorType>(VecOp->getType())->getNumElements(); + unsigned VWidth = VecOp->getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&IE, AllOnesEltMask, UndefElts)) { @@ -625,6 +803,14 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { return &IE; } + if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE)) + return Shuf; + + // Turn a sequence of inserts that broadcasts a scalar into a single + // insert + shufflevector. + if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE)) + return Broadcast; + return nullptr; } @@ -903,8 +1089,7 @@ static void recognizeIdentityMask(const SmallVectorImpl<int> &Mask, // +--+--+--+--+ static bool isShuffleExtractingFromLHS(ShuffleVectorInst &SVI, SmallVector<int, 16> &Mask) { - unsigned LHSElems = - cast<VectorType>(SVI.getOperand(0)->getType())->getNumElements(); + unsigned LHSElems = SVI.getOperand(0)->getType()->getVectorNumElements(); unsigned MaskElems = Mask.size(); unsigned BegIdx = Mask.front(); unsigned EndIdx = Mask.back(); @@ -928,7 +1113,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (isa<UndefValue>(SVI.getOperand(2))) return replaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); - unsigned VWidth = cast<VectorType>(SVI.getType())->getNumElements(); + unsigned VWidth = SVI.getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); @@ -940,7 +1125,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { MadeChange = true; } - unsigned LHSWidth = cast<VectorType>(LHS->getType())->getNumElements(); + unsigned LHSWidth = LHS->getType()->getVectorNumElements(); // Canonicalize shuffle(x ,x,mask) -> shuffle(x, undef,mask') // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask'). 
@@ -1143,11 +1328,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (LHSShuffle) { LHSOp0 = LHSShuffle->getOperand(0); LHSOp1 = LHSShuffle->getOperand(1); - LHSOp0Width = cast<VectorType>(LHSOp0->getType())->getNumElements(); + LHSOp0Width = LHSOp0->getType()->getVectorNumElements(); } if (RHSShuffle) { RHSOp0 = RHSShuffle->getOperand(0); - RHSOp0Width = cast<VectorType>(RHSOp0->getType())->getNumElements(); + RHSOp0Width = RHSOp0->getType()->getVectorNumElements(); } Value* newLHS = LHS; Value* newRHS = RHS; diff --git a/contrib/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/contrib/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 377ccb9..27fc34d 100644 --- a/contrib/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/contrib/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -177,11 +177,10 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1) { return false; // TODO: Enhance logic for other BinOps and remove this check. - auto AssocOpcode = BinOp1->getOpcode(); - if (AssocOpcode != Instruction::Xor && AssocOpcode != Instruction::And && - AssocOpcode != Instruction::Or) + if (!BinOp1->isBitwiseLogicOp()) return false; + auto AssocOpcode = BinOp1->getOpcode(); auto *BinOp2 = dyn_cast<BinaryOperator>(Cast->getOperand(0)); if (!BinOp2 || !BinOp2->hasOneUse() || BinOp2->getOpcode() != AssocOpcode) return false; @@ -684,14 +683,14 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { if (SI0->getCondition() == SI1->getCondition()) { Value *SI = nullptr; if (Value *V = SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(), - SI1->getFalseValue(), DL, TLI, DT, AC)) + SI1->getFalseValue(), DL, &TLI, &DT, &AC)) SI = Builder->CreateSelect(SI0->getCondition(), Builder->CreateBinOp(TopLevelOpcode, SI0->getTrueValue(), SI1->getTrueValue()), V); if (Value *V = SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(), - SI1->getTrueValue(), DL, TLI, DT, AC)) + SI1->getTrueValue(), DL, &TLI, &DT, &AC)) SI = Builder->CreateSelect( SI0->getCondition(), V, Builder->CreateBinOp(TopLevelOpcode, SI0->getFalseValue(), @@ -741,17 +740,18 @@ Value *InstCombiner::dyn_castFNegVal(Value *V, bool IgnoreZeroSign) const { return nullptr; } -static Value *FoldOperationIntoSelectOperand(Instruction &I, Value *SO, +static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, InstCombiner *IC) { - if (CastInst *CI = dyn_cast<CastInst>(&I)) { - return IC->Builder->CreateCast(CI->getOpcode(), SO, I.getType()); - } + if (auto *Cast = dyn_cast<CastInst>(&I)) + return IC->Builder->CreateCast(Cast->getOpcode(), SO, I.getType()); + + assert(I.isBinaryOp() && "Unexpected opcode for select folding"); // Figure out if the constant is the left or the right argument. 
bool ConstIsRHS = isa<Constant>(I.getOperand(1)); Constant *ConstOperand = cast<Constant>(I.getOperand(ConstIsRHS)); - if (Constant *SOC = dyn_cast<Constant>(SO)) { + if (auto *SOC = dyn_cast<Constant>(SO)) { if (ConstIsRHS) return ConstantExpr::get(I.getOpcode(), SOC, ConstOperand); return ConstantExpr::get(I.getOpcode(), ConstOperand, SOC); @@ -761,78 +761,65 @@ static Value *FoldOperationIntoSelectOperand(Instruction &I, Value *SO, if (!ConstIsRHS) std::swap(Op0, Op1); - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(&I)) { - Value *RI = IC->Builder->CreateBinOp(BO->getOpcode(), Op0, Op1, - SO->getName()+".op"); - Instruction *FPInst = dyn_cast<Instruction>(RI); - if (FPInst && isa<FPMathOperator>(FPInst)) - FPInst->copyFastMathFlags(BO); - return RI; - } - if (ICmpInst *CI = dyn_cast<ICmpInst>(&I)) - return IC->Builder->CreateICmp(CI->getPredicate(), Op0, Op1, - SO->getName()+".cmp"); - if (FCmpInst *CI = dyn_cast<FCmpInst>(&I)) - return IC->Builder->CreateICmp(CI->getPredicate(), Op0, Op1, - SO->getName()+".cmp"); - llvm_unreachable("Unknown binary instruction type!"); + auto *BO = cast<BinaryOperator>(&I); + Value *RI = IC->Builder->CreateBinOp(BO->getOpcode(), Op0, Op1, + SO->getName() + ".op"); + auto *FPInst = dyn_cast<Instruction>(RI); + if (FPInst && isa<FPMathOperator>(FPInst)) + FPInst->copyFastMathFlags(BO); + return RI; } -/// Given an instruction with a select as one operand and a constant as the -/// other operand, try to fold the binary operator into the select arguments. -/// This also works for Cast instructions, which obviously do not have a second -/// operand. Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) { - // Don't modify shared select instructions - if (!SI->hasOneUse()) return nullptr; - Value *TV = SI->getOperand(1); - Value *FV = SI->getOperand(2); - - if (isa<Constant>(TV) || isa<Constant>(FV)) { - // Bool selects with constant operands can be folded to logical ops. - if (SI->getType()->isIntegerTy(1)) return nullptr; - - // If it's a bitcast involving vectors, make sure it has the same number of - // elements on both sides. - if (BitCastInst *BC = dyn_cast<BitCastInst>(&Op)) { - VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy()); - VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy()); - - // Verify that either both or neither are vectors. - if ((SrcTy == nullptr) != (DestTy == nullptr)) return nullptr; - // If vectors, verify that they have the same number of elements. - if (SrcTy && SrcTy->getNumElements() != DestTy->getNumElements()) - return nullptr; - } + // Don't modify shared select instructions. + if (!SI->hasOneUse()) + return nullptr; - // Test if a CmpInst instruction is used exclusively by a select as - // part of a minimum or maximum operation. If so, refrain from doing - // any other folding. This helps out other analyses which understand - // non-obfuscated minimum and maximum idioms, such as ScalarEvolution - // and CodeGen. And in this case, at least one of the comparison - // operands has at least one user besides the compare (the select), - // which would often largely negate the benefit of folding anyway. 
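// A sketch (not part of the patch) of the identity behind FoldOpIntoSelect:
// applying a binary operator with a constant operand to a select is the same
// as applying it to each arm of the select and selecting afterwards.
#include <cassert>

static int opAfterSelect(bool Cond, int T, int F, int K) {
  int Sel = Cond ? T : F;            // select(c, t, f)
  return Sel + K;                    // add (select ...), K
}

static int selectAfterOp(bool Cond, int T, int F, int K) {
  return Cond ? (T + K) : (F + K);   // select(c, add(t, K), add(f, K))
}

int main() {
  for (int Cond = 0; Cond <= 1; ++Cond)
    assert(opAfterSelect(Cond != 0, 10, 20, 5) ==
           selectAfterOp(Cond != 0, 10, 20, 5));
  return 0;
}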
- if (auto *CI = dyn_cast<CmpInst>(SI->getCondition())) { - if (CI->hasOneUse()) { - Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); - if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || - (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) - return nullptr; - } - } + Value *TV = SI->getTrueValue(); + Value *FV = SI->getFalseValue(); + if (!(isa<Constant>(TV) || isa<Constant>(FV))) + return nullptr; - Value *SelectTrueVal = FoldOperationIntoSelectOperand(Op, TV, this); - Value *SelectFalseVal = FoldOperationIntoSelectOperand(Op, FV, this); + // Bool selects with constant operands can be folded to logical ops. + if (SI->getType()->getScalarType()->isIntegerTy(1)) + return nullptr; - return SelectInst::Create(SI->getCondition(), - SelectTrueVal, SelectFalseVal); + // If it's a bitcast involving vectors, make sure it has the same number of + // elements on both sides. + if (auto *BC = dyn_cast<BitCastInst>(&Op)) { + VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy()); + VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy()); + + // Verify that either both or neither are vectors. + if ((SrcTy == nullptr) != (DestTy == nullptr)) + return nullptr; + + // If vectors, verify that they have the same number of elements. + if (SrcTy && SrcTy->getNumElements() != DestTy->getNumElements()) + return nullptr; } - return nullptr; + + // Test if a CmpInst instruction is used exclusively by a select as + // part of a minimum or maximum operation. If so, refrain from doing + // any other folding. This helps out other analyses which understand + // non-obfuscated minimum and maximum idioms, such as ScalarEvolution + // and CodeGen. And in this case, at least one of the comparison + // operands has at least one user besides the compare (the select), + // which would often largely negate the benefit of folding anyway. + if (auto *CI = dyn_cast<CmpInst>(SI->getCondition())) { + if (CI->hasOneUse()) { + Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); + if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || + (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) + return nullptr; + } + } + + Value *NewTV = foldOperationIntoSelectOperand(Op, TV, this); + Value *NewFV = foldOperationIntoSelectOperand(Op, FV, this); + return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } -/// Given a binary operator, cast instruction, or select which has a PHI node as -/// operand #0, see if we can fold the instruction into the PHI (which is only -/// possible if all operands to the PHI are constants). Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { PHINode *PN = cast<PHINode>(I.getOperand(0)); unsigned NumPHIValues = PN->getNumIncomingValues(); @@ -877,7 +864,7 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { // If the incoming non-constant value is in I's block, we will remove one // instruction, but insert another equivalent one, leading to infinite // instcombine. 
- if (isPotentiallyReachable(I.getParent(), NonConstBB, DT, LI)) + if (isPotentiallyReachable(I.getParent(), NonConstBB, &DT, LI)) return nullptr; } @@ -970,6 +957,19 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { return replaceInstUsesWith(I, NewPN); } +Instruction *InstCombiner::foldOpWithConstantIntoOperand(Instruction &I) { + assert(isa<Constant>(I.getOperand(1)) && "Unexpected operand type"); + + if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) { + if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) + return NewSel; + } else if (isa<PHINode>(I.getOperand(0))) { + if (Instruction *NewPhi = FoldOpIntoPhi(I)) + return NewPhi; + } + return nullptr; +} + /// Given a pointer type and a constant offset, determine whether or not there /// is a sequence of GEP indices into the pointed type that will land us at the /// specified offset. If so, fill them into NewIndices and return the resultant @@ -1379,7 +1379,8 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); - if (Value *V = SimplifyGEPInst(GEP.getSourceElementType(), Ops, DL, TLI, DT, AC)) + if (Value *V = + SimplifyGEPInst(GEP.getSourceElementType(), Ops, DL, &TLI, &DT, &AC)) return replaceInstUsesWith(GEP, V); Value *PtrOp = GEP.getOperand(0); @@ -1394,7 +1395,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { for (User::op_iterator I = GEP.op_begin() + 1, E = GEP.op_end(); I != E; ++I, ++GTI) { // Skip indices into struct types. - if (isa<StructType>(*GTI)) + if (GTI.isStruct()) continue; // Index type should have the same width as IntPtr @@ -1551,7 +1552,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { bool EndsWithSequential = false; for (gep_type_iterator I = gep_type_begin(*Src), E = gep_type_end(*Src); I != E; ++I) - EndsWithSequential = !(*I)->isStructTy(); + EndsWithSequential = I.isSequential(); // Can we combine the two pointer arithmetics offsets? if (EndsWithSequential) { @@ -1860,7 +1861,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (!Offset) { // If the bitcast is of an allocation, and the allocation will be // converted to match the type of the cast, don't touch this. - if (isa<AllocaInst>(Operand) || isAllocationFn(Operand, TLI)) { + if (isa<AllocaInst>(Operand) || isAllocationFn(Operand, &TLI)) { // See if the bitcast simplifies, if so, don't nuke this GEP yet. 
if (Instruction *I = visitBitCast(*BCI)) { if (I != BCI) { @@ -1898,6 +1899,25 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } + if (!GEP.isInBounds()) { + unsigned PtrWidth = + DL.getPointerSizeInBits(PtrOp->getType()->getPointerAddressSpace()); + APInt BasePtrOffset(PtrWidth, 0); + Value *UnderlyingPtrOp = + PtrOp->stripAndAccumulateInBoundsConstantOffsets(DL, + BasePtrOffset); + if (auto *AI = dyn_cast<AllocaInst>(UnderlyingPtrOp)) { + if (GEP.accumulateConstantOffset(DL, BasePtrOffset) && + BasePtrOffset.isNonNegative()) { + APInt AllocSize(PtrWidth, DL.getTypeAllocSize(AI->getAllocatedType())); + if (BasePtrOffset.ule(AllocSize)) { + return GetElementPtrInst::CreateInBounds( + PtrOp, makeArrayRef(Ops).slice(1), GEP.getName()); + } + } + } + } + return nullptr; } @@ -1963,8 +1983,8 @@ isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakVH> &Users, MemIntrinsic *MI = cast<MemIntrinsic>(II); if (MI->isVolatile() || MI->getRawDest() != PI) return false; + LLVM_FALLTHROUGH; } - // fall through case Intrinsic::dbg_declare: case Intrinsic::dbg_value: case Intrinsic::invariant_start: @@ -2002,7 +2022,7 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { // to null and free calls, delete the calls and replace the comparisons with // true or false as appropriate. SmallVector<WeakVH, 64> Users; - if (isAllocSiteRemovable(&MI, Users, TLI)) { + if (isAllocSiteRemovable(&MI, Users, &TLI)) { for (unsigned i = 0, e = Users.size(); i != e; ++i) { // Lowering all @llvm.objectsize calls first because they may // use a bitcast/GEP of the alloca we are removing. @@ -2013,12 +2033,9 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { if (II->getIntrinsicID() == Intrinsic::objectsize) { - uint64_t Size; - if (!getObjectSize(II->getArgOperand(0), Size, DL, TLI)) { - ConstantInt *CI = cast<ConstantInt>(II->getArgOperand(1)); - Size = CI->isZero() ? -1ULL : 0; - } - replaceInstUsesWith(*I, ConstantInt::get(I->getType(), Size)); + ConstantInt *Result = lowerObjectSizeCall(II, DL, &TLI, + /*MustSucceed=*/true); + replaceInstUsesWith(*I, Result); eraseInstFromFunction(*I); Users[i] = nullptr; // Skip examining in the next loop. } @@ -2218,6 +2235,20 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { Value *Cond = SI.getCondition(); + Value *Op0; + ConstantInt *AddRHS; + if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) { + // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. + for (SwitchInst::CaseIt CaseIter : SI.cases()) { + Constant *NewCase = ConstantExpr::getSub(CaseIter.getCaseValue(), AddRHS); + assert(isa<ConstantInt>(NewCase) && + "Result of expression should be constant"); + CaseIter.setValue(cast<ConstantInt>(NewCase)); + } + SI.setCondition(Op0); + return &SI; + } + unsigned BitWidth = cast<IntegerType>(Cond->getType())->getBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); computeKnownBits(Cond, KnownZero, KnownOne, 0, &SI); @@ -2238,43 +2269,20 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { // Shrink the condition operand if the new type is smaller than the old type. // This may produce a non-standard type for the switch, but that's ok because // the backend should extend back to a legal type for the target. 
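// A sketch (not part of the patch) of the switch rewrite added above:
// comparing X + K against a case value C is equivalent to comparing X against
// C - K, because both comparisons wrap modulo 2^N; uint32_t models an i32.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t K = 4;                       // switch (X + 4)
  const uint32_t Case = 1;                    // case 1:
  const uint32_t NewCase = Case - K;          // case -3 (0xFFFFFFFD)
  const uint32_t Samples[] = {0u, 1u, 4u, 0xFFFFFFFDu, 0xFFFFFFFFu};
  for (uint32_t X : Samples)
    assert(((X + K) == Case) == (X == NewCase));
  return 0;
}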
- bool TruncCond = false; if (NewWidth > 0 && NewWidth < BitWidth) { - TruncCond = true; IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth); Builder->SetInsertPoint(&SI); Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc"); SI.setCondition(NewCond); - for (auto &C : SI.cases()) - static_cast<SwitchInst::CaseIt *>(&C)->setValue(ConstantInt::get( - SI.getContext(), C.getCaseValue()->getValue().trunc(NewWidth))); - } - - ConstantInt *AddRHS = nullptr; - if (match(Cond, m_Add(m_Value(), m_ConstantInt(AddRHS)))) { - Instruction *I = cast<Instruction>(Cond); - // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. - for (SwitchInst::CaseIt i = SI.case_begin(), e = SI.case_end(); i != e; - ++i) { - ConstantInt *CaseVal = i.getCaseValue(); - Constant *LHS = CaseVal; - if (TruncCond) { - LHS = LeadingKnownZeros - ? ConstantExpr::getZExt(CaseVal, Cond->getType()) - : ConstantExpr::getSExt(CaseVal, Cond->getType()); - } - Constant *NewCaseVal = ConstantExpr::getSub(LHS, AddRHS); - assert(isa<ConstantInt>(NewCaseVal) && - "Result of expression should be constant"); - i.setValue(cast<ConstantInt>(NewCaseVal)); + for (SwitchInst::CaseIt CaseIter : SI.cases()) { + APInt TruncatedCase = CaseIter.getCaseValue()->getValue().trunc(NewWidth); + CaseIter.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); } - SI.setCondition(I->getOperand(0)); - Worklist.Add(I); return &SI; } - return TruncCond ? &SI : nullptr; + return nullptr; } Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { @@ -2284,7 +2292,7 @@ Instruction *InstCombiner::visitExtractValueInst(ExtractValueInst &EV) { return replaceInstUsesWith(EV, Agg); if (Value *V = - SimplifyExtractValueInst(Agg, EV.getIndices(), DL, TLI, DT, AC)) + SimplifyExtractValueInst(Agg, EV.getIndices(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(EV, V); if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) { @@ -2560,7 +2568,7 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) { // remove it from the filter. An unexpected type handler may be // set up for a call site which throws an exception of the same // type caught. In order for the exception thrown by the unexpected - // handler to propogate correctly, the filter must be correctly + // handler to propagate correctly, the filter must be correctly // described for the call site. // // Example: @@ -2813,7 +2821,7 @@ bool InstCombiner::run() { if (I == nullptr) continue; // skip null values. // Check to see if we can DCE the instruction. - if (isInstructionTriviallyDead(I, TLI)) { + if (isInstructionTriviallyDead(I, &TLI)) { DEBUG(dbgs() << "IC: DCE: " << *I << '\n'); eraseInstFromFunction(*I); ++NumDeadInst; @@ -2824,13 +2832,13 @@ bool InstCombiner::run() { // Instruction isn't dead, see if we can constant propagate it. if (!I->use_empty() && (I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) { - if (Constant *C = ConstantFoldInstruction(I, DL, TLI)) { + if (Constant *C = ConstantFoldInstruction(I, DL, &TLI)) { DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *I << '\n'); // Add operands to the worklist. replaceInstUsesWith(*I, C); ++NumConstProp; - if (isInstructionTriviallyDead(I, TLI)) + if (isInstructionTriviallyDead(I, &TLI)) eraseInstFromFunction(*I); MadeIRChange = true; continue; @@ -2839,20 +2847,21 @@ bool InstCombiner::run() { // In general, it is possible for computeKnownBits to determine all bits in // a value even when the operands are not all constants. 
- if (ExpensiveCombines && !I->use_empty() && I->getType()->isIntegerTy()) { - unsigned BitWidth = I->getType()->getScalarSizeInBits(); + Type *Ty = I->getType(); + if (ExpensiveCombines && !I->use_empty() && Ty->isIntOrIntVectorTy()) { + unsigned BitWidth = Ty->getScalarSizeInBits(); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); computeKnownBits(I, KnownZero, KnownOne, /*Depth*/0, I); if ((KnownZero | KnownOne).isAllOnesValue()) { - Constant *C = ConstantInt::get(I->getContext(), KnownOne); + Constant *C = ConstantInt::get(Ty, KnownOne); DEBUG(dbgs() << "IC: ConstFold (all bits known) to: " << *C << " from: " << *I << '\n'); // Add operands to the worklist. replaceInstUsesWith(*I, C); ++NumConstProp; - if (isInstructionTriviallyDead(I, TLI)) + if (isInstructionTriviallyDead(I, &TLI)) eraseInstFromFunction(*I); MadeIRChange = true; continue; @@ -2883,7 +2892,7 @@ bool InstCombiner::run() { // If the user is one of our immediate successors, and if that successor // only has us as a predecessors (we'd have to split the critical edge // otherwise), we can keep going. - if (UserIsSuccessor && UserParent->getSinglePredecessor()) { + if (UserIsSuccessor && UserParent->getUniquePredecessor()) { // Okay, the CFG is simple enough, try to sink this instruction. if (TryToSinkInstruction(I, UserParent)) { DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); @@ -2941,14 +2950,12 @@ bool InstCombiner::run() { eraseInstFromFunction(*I); } else { -#ifndef NDEBUG DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n' << " New = " << *I << '\n'); -#endif // If the instruction was modified, it's possible that it is now dead. // if so, remove it. - if (isInstructionTriviallyDead(I, TLI)) { + if (isInstructionTriviallyDead(I, &TLI)) { eraseInstFromFunction(*I); } else { Worklist.Add(I); @@ -2981,7 +2988,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, Worklist.push_back(BB); SmallVector<Instruction*, 128> InstrsForInstCombineWorklist; - DenseMap<ConstantExpr*, Constant*> FoldedConstants; + DenseMap<Constant *, Constant *> FoldedConstants; do { BB = Worklist.pop_back_val(); @@ -3017,17 +3024,17 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, // See if we can constant fold its operands. for (User::op_iterator i = Inst->op_begin(), e = Inst->op_end(); i != e; ++i) { - ConstantExpr *CE = dyn_cast<ConstantExpr>(i); - if (CE == nullptr) + if (!isa<ConstantVector>(i) && !isa<ConstantExpr>(i)) continue; - Constant *&FoldRes = FoldedConstants[CE]; + auto *C = cast<Constant>(i); + Constant *&FoldRes = FoldedConstants[C]; if (!FoldRes) - FoldRes = ConstantFoldConstantExpression(CE, DL, TLI); + FoldRes = ConstantFoldConstant(C, DL, TLI); if (!FoldRes) - FoldRes = CE; + FoldRes = C; - if (FoldRes != CE) { + if (FoldRes != C) { *i = FoldRes; MadeIRChange = true; } @@ -3120,8 +3127,15 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, /// Builder - This is an IRBuilder that automatically inserts new /// instructions into the worklist when they are created. 
- IRBuilder<TargetFolder, InstCombineIRInserter> Builder( - F.getContext(), TargetFolder(DL), InstCombineIRInserter(Worklist, &AC)); + IRBuilder<TargetFolder, IRBuilderCallbackInserter> Builder( + F.getContext(), TargetFolder(DL), + IRBuilderCallbackInserter([&Worklist, &AC](Instruction *I) { + Worklist.Add(I); + + using namespace llvm::PatternMatch; + if (match(I, m_Intrinsic<Intrinsic::assume>())) + AC.registerAssumption(cast<CallInst>(I)); + })); // Lower dbg.declare intrinsics otherwise their value may be clobbered // by instcombiner. @@ -3137,7 +3151,7 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, bool Changed = prepareICWorklistFromFunction(F, DL, &TLI, Worklist); InstCombiner IC(Worklist, &Builder, F.optForMinSize(), ExpensiveCombines, - AA, &AC, &TLI, &DT, DL, LI); + AA, AC, TLI, DT, DL, LI); Changed |= IC.run(); if (!Changed) @@ -3148,7 +3162,7 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, } PreservedAnalyses InstCombinePass::run(Function &F, - AnalysisManager<Function> &AM) { + FunctionAnalysisManager &AM) { auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); diff --git a/contrib/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/contrib/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 43d1b37..f5e9e7d 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -54,6 +54,9 @@ #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> +#include <iomanip> +#include <limits> +#include <sstream> #include <string> #include <system_error> @@ -64,8 +67,8 @@ using namespace llvm; static const uint64_t kDefaultShadowScale = 3; static const uint64_t kDefaultShadowOffset32 = 1ULL << 29; static const uint64_t kDefaultShadowOffset64 = 1ULL << 44; +static const uint64_t kDynamicShadowSentinel = ~(uint64_t)0; static const uint64_t kIOSShadowOffset32 = 1ULL << 30; -static const uint64_t kIOSShadowOffset64 = 0x120200000; static const uint64_t kIOSSimShadowOffset32 = 1ULL << 30; static const uint64_t kIOSSimShadowOffset64 = kDefaultShadowOffset64; static const uint64_t kSmallX86_64ShadowOffset = 0x7FFF8000; // < 2G. @@ -78,8 +81,8 @@ static const uint64_t kAArch64_ShadowOffset64 = 1ULL << 36; static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; static const uint64_t kWindowsShadowOffset32 = 3ULL << 28; -// TODO(wwchrome): Experimental for asan Win64, may change. -static const uint64_t kWindowsShadowOffset64 = 0x1ULL << 45; // 32TB. +// The shadow memory space is dynamically allocated. 
+static const uint64_t kWindowsShadowOffset64 = kDynamicShadowSentinel; static const size_t kMinStackMallocSize = 1 << 6; // 64B static const size_t kMaxStackMallocSize = 1 << 16; // 64K @@ -111,6 +114,7 @@ static const char *const kAsanStackFreeNameTemplate = "__asan_stack_free_"; static const char *const kAsanGenPrefix = "__asan_gen_"; static const char *const kODRGenPrefix = "__odr_asan_gen_"; static const char *const kSanCovGenPrefix = "__sancov_gen_"; +static const char *const kAsanSetShadowPrefix = "__asan_set_shadow_"; static const char *const kAsanPoisonStackMemoryName = "__asan_poison_stack_memory"; static const char *const kAsanUnpoisonStackMemoryName = @@ -121,6 +125,9 @@ static const char *const kAsanGlobalsRegisteredFlagName = static const char *const kAsanOptionDetectUseAfterReturn = "__asan_option_detect_stack_use_after_return"; +static const char *const kAsanShadowMemoryDynamicAddress = + "__asan_shadow_memory_dynamic_address"; + static const char *const kAsanAllocaPoison = "__asan_alloca_poison"; static const char *const kAsanAllocasUnpoison = "__asan_allocas_unpoison"; @@ -153,6 +160,11 @@ static cl::opt<bool> ClAlwaysSlowPath( "asan-always-slow-path", cl::desc("use instrumentation with slow path for all accesses"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClForceDynamicShadow( + "asan-force-dynamic-shadow", + cl::desc("Load shadow address into a local variable for each function"), + cl::Hidden, cl::init(false)); + // This flag limits the number of instructions to be instrumented // in any given BB. Normally, this should be set to unlimited (INT_MAX), // but due to http://llvm.org/bugs/show_bug.cgi?id=12652 we temporary @@ -164,6 +176,11 @@ static cl::opt<int> ClMaxInsnsToInstrumentPerBB( // This flag may need to be replaced with -f[no]asan-stack. static cl::opt<bool> ClStack("asan-stack", cl::desc("Handle stack memory"), cl::Hidden, cl::init(true)); +static cl::opt<uint32_t> ClMaxInlinePoisoningSize( + "asan-max-inline-poisoning-size", + cl::desc( + "Inline shadow poisoning for blocks up to the given size in bytes."), + cl::Hidden, cl::init(64)); static cl::opt<bool> ClUseAfterReturn("asan-use-after-return", cl::desc("Check stack-use-after-return"), cl::Hidden, cl::init(true)); @@ -196,9 +213,10 @@ static cl::opt<std::string> ClMemoryAccessCallbackPrefix( "asan-memory-access-callback-prefix", cl::desc("Prefix for memory access callbacks"), cl::Hidden, cl::init("__asan_")); -static cl::opt<bool> ClInstrumentAllocas("asan-instrument-allocas", - cl::desc("instrument dynamic allocas"), - cl::Hidden, cl::init(true)); +static cl::opt<bool> + ClInstrumentDynamicAllocas("asan-instrument-dynamic-allocas", + cl::desc("instrument dynamic allocas"), + cl::Hidden, cl::init(true)); static cl::opt<bool> ClSkipPromotableAllocas( "asan-skip-promotable-allocas", cl::desc("Do not instrument promotable allocas"), cl::Hidden, @@ -250,7 +268,7 @@ static cl::opt<bool> cl::desc("Use linker features to support dead " "code stripping of globals " "(Mach-O only)"), - cl::Hidden, cl::init(false)); + cl::Hidden, cl::init(true)); // Debug flags. 
static cl::opt<int> ClDebug("asan-debug", cl::desc("debug"), cl::Hidden, @@ -261,7 +279,7 @@ static cl::opt<std::string> ClDebugFunc("asan-debug-func", cl::Hidden, cl::desc("Debug func")); static cl::opt<int> ClDebugMin("asan-debug-min", cl::desc("Debug min inst"), cl::Hidden, cl::init(-1)); -static cl::opt<int> ClDebugMax("asan-debug-max", cl::desc("Debug man inst"), +static cl::opt<int> ClDebugMax("asan-debug-max", cl::desc("Debug max inst"), cl::Hidden, cl::init(-1)); STATISTIC(NumInstrumentedReads, "Number of instrumented reads"); @@ -411,13 +429,19 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, Mapping.Offset = kMIPS64_ShadowOffset64; else if (IsIOS) // If we're targeting iOS and x86, the binary is built for iOS simulator. - Mapping.Offset = IsX86_64 ? kIOSSimShadowOffset64 : kIOSShadowOffset64; + // We are using dynamic shadow offset on the 64-bit devices. + Mapping.Offset = + IsX86_64 ? kIOSSimShadowOffset64 : kDynamicShadowSentinel; else if (IsAArch64) Mapping.Offset = kAArch64_ShadowOffset64; else Mapping.Offset = kDefaultShadowOffset64; } + if (ClForceDynamicShadow) { + Mapping.Offset = kDynamicShadowSentinel; + } + Mapping.Scale = kDefaultShadowScale; if (ClMappingScale.getNumOccurrences() > 0) { Mapping.Scale = ClMappingScale; @@ -433,7 +457,8 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, // we could OR the constant in a single instruction, but it's more // efficient to load it once and use indexed addressing. Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ - && !(Mapping.Offset & (Mapping.Offset - 1)); + && !(Mapping.Offset & (Mapping.Offset - 1)) + && Mapping.Offset != kDynamicShadowSentinel; return Mapping; } @@ -450,42 +475,47 @@ struct AddressSanitizer : public FunctionPass { bool UseAfterScope = false) : FunctionPass(ID), CompileKernel(CompileKernel || ClEnableKasan), Recover(Recover || ClRecover), - UseAfterScope(UseAfterScope || ClUseAfterScope) { + UseAfterScope(UseAfterScope || ClUseAfterScope), + LocalDynamicShadow(nullptr) { initializeAddressSanitizerPass(*PassRegistry::getPassRegistry()); } - const char *getPassName() const override { + StringRef getPassName() const override { return "AddressSanitizerFunctionPass"; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } - uint64_t getAllocaSizeInBytes(AllocaInst *AI) const { + uint64_t getAllocaSizeInBytes(const AllocaInst &AI) const { uint64_t ArraySize = 1; - if (AI->isArrayAllocation()) { - ConstantInt *CI = dyn_cast<ConstantInt>(AI->getArraySize()); + if (AI.isArrayAllocation()) { + const ConstantInt *CI = dyn_cast<ConstantInt>(AI.getArraySize()); assert(CI && "non-constant array size"); ArraySize = CI->getZExtValue(); } - Type *Ty = AI->getAllocatedType(); + Type *Ty = AI.getAllocatedType(); uint64_t SizeInBytes = - AI->getModule()->getDataLayout().getTypeAllocSize(Ty); + AI.getModule()->getDataLayout().getTypeAllocSize(Ty); return SizeInBytes * ArraySize; } /// Check if we want (and can) handle this alloca. - bool isInterestingAlloca(AllocaInst &AI); + bool isInterestingAlloca(const AllocaInst &AI); /// If it is an interesting memory access, return the PointerOperand /// and set IsWrite/Alignment. Otherwise return nullptr. + /// MaybeMask is an output parameter for the mask Value, if we're looking at a + /// masked load/store. 
Value *isInterestingMemoryAccess(Instruction *I, bool *IsWrite, - uint64_t *TypeSize, unsigned *Alignment); + uint64_t *TypeSize, unsigned *Alignment, + Value **MaybeMask = nullptr); void instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, Instruction *I, bool UseCalls, const DataLayout &DL); void instrumentPointerComparisonOrSubtraction(Instruction *I); void instrumentAddress(Instruction *OrigIns, Instruction *InsertBefore, Value *Addr, uint32_t TypeSize, bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp); - void instrumentUnusualSizeOrAlignment(Instruction *I, Value *Addr, + void instrumentUnusualSizeOrAlignment(Instruction *I, + Instruction *InsertBefore, Value *Addr, uint32_t TypeSize, bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp); @@ -498,6 +528,7 @@ struct AddressSanitizer : public FunctionPass { Value *memToShadow(Value *Shadow, IRBuilder<> &IRB); bool runOnFunction(Function &F) override; bool maybeInsertAsanInitAtFunctionEntry(Function &F); + void maybeInsertDynamicShadowAtFunctionEntry(Function &F); void markEscapedLocalAllocas(Function &F); bool doInitialization(Module &M) override; bool doFinalization(Module &M) override; @@ -519,8 +550,12 @@ struct AddressSanitizer : public FunctionPass { FunctionStateRAII(AddressSanitizer *Pass) : Pass(Pass) { assert(Pass->ProcessedAllocas.empty() && "last pass forgot to clear cache"); + assert(!Pass->LocalDynamicShadow); + } + ~FunctionStateRAII() { + Pass->LocalDynamicShadow = nullptr; + Pass->ProcessedAllocas.clear(); } - ~FunctionStateRAII() { Pass->ProcessedAllocas.clear(); } }; LLVMContext *C; @@ -544,8 +579,9 @@ struct AddressSanitizer : public FunctionPass { Function *AsanMemoryAccessCallbackSized[2][2]; Function *AsanMemmove, *AsanMemcpy, *AsanMemset; InlineAsm *EmptyAsm; + Value *LocalDynamicShadow; GlobalsMetadata GlobalsMD; - DenseMap<AllocaInst *, bool> ProcessedAllocas; + DenseMap<const AllocaInst *, bool> ProcessedAllocas; friend struct FunctionStackPoisoner; }; @@ -558,14 +594,31 @@ class AddressSanitizerModule : public ModulePass { Recover(Recover || ClRecover) {} bool runOnModule(Module &M) override; static char ID; // Pass identification, replacement for typeid - const char *getPassName() const override { return "AddressSanitizerModule"; } + StringRef getPassName() const override { return "AddressSanitizerModule"; } - private: +private: void initializeCallbacks(Module &M); bool InstrumentGlobals(IRBuilder<> &IRB, Module &M); + void InstrumentGlobalsCOFF(IRBuilder<> &IRB, Module &M, + ArrayRef<GlobalVariable *> ExtendedGlobals, + ArrayRef<Constant *> MetadataInitializers); + void InstrumentGlobalsMachO(IRBuilder<> &IRB, Module &M, + ArrayRef<GlobalVariable *> ExtendedGlobals, + ArrayRef<Constant *> MetadataInitializers); + void + InstrumentGlobalsWithMetadataArray(IRBuilder<> &IRB, Module &M, + ArrayRef<GlobalVariable *> ExtendedGlobals, + ArrayRef<Constant *> MetadataInitializers); + + GlobalVariable *CreateMetadataGlobal(Module &M, Constant *Initializer, + StringRef OriginalName); + void SetComdatForGlobalMetadata(GlobalVariable *G, GlobalVariable *Metadata); + IRBuilder<> CreateAsanModuleDtor(Module &M); + bool ShouldInstrumentGlobal(GlobalVariable *G); bool ShouldUseMachOGlobalsSection() const; + StringRef getGlobalMetadataSection() const; void poisonOneInitializer(Function &GlobalInit, GlobalValue *ModuleName); void createInitializerPoisonCalls(Module &M, GlobalValue *ModuleName); size_t MinRedzoneSizeForGlobal() const { @@ -606,12 +659,13 @@ struct FunctionStackPoisoner : public 
InstVisitor<FunctionStackPoisoner> { ShadowMapping Mapping; SmallVector<AllocaInst *, 16> AllocaVec; - SmallSetVector<AllocaInst *, 16> NonInstrumentedStaticAllocaVec; + SmallVector<AllocaInst *, 16> StaticAllocasToMoveUp; SmallVector<Instruction *, 8> RetVec; unsigned StackAlignment; Function *AsanStackMallocFunc[kMaxAsanStackMallocSizeClass + 1], *AsanStackFreeFunc[kMaxAsanStackMallocSizeClass + 1]; + Function *AsanSetShadowFunc[0x100] = {}; Function *AsanPoisonStackMemoryFunc, *AsanUnpoisonStackMemoryFunc; Function *AsanAllocaPoisonFunc, *AsanAllocasUnpoisonFunc; @@ -622,7 +676,8 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { uint64_t Size; bool DoPoison; }; - SmallVector<AllocaPoisonCall, 8> AllocaPoisonCallVec; + SmallVector<AllocaPoisonCall, 8> DynamicAllocaPoisonCallVec; + SmallVector<AllocaPoisonCall, 8> StaticAllocaPoisonCallVec; SmallVector<AllocaInst *, 1> DynamicAllocaVec; SmallVector<IntrinsicInst *, 1> StackRestoreVec; @@ -657,7 +712,8 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { initializeCallbacks(*F.getParent()); - poisonStack(); + processDynamicAllocas(); + processStaticAllocas(); if (ClDebugStack) { DEBUG(dbgs() << F); @@ -668,7 +724,8 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { // Finds all Alloca instructions and puts // poisoned red zones around all of them. // Then unpoison everything back before the function returns. - void poisonStack(); + void processStaticAllocas(); + void processDynamicAllocas(); void createDynamicAllocasInitStorage(); @@ -676,6 +733,12 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { /// \brief Collect all Ret instructions. void visitReturnInst(ReturnInst &RI) { RetVec.push_back(&RI); } + /// \brief Collect all Resume instructions. + void visitResumeInst(ResumeInst &RI) { RetVec.push_back(&RI); } + + /// \brief Collect all CatchReturnInst instructions. + void visitCleanupReturnInst(CleanupReturnInst &CRI) { RetVec.push_back(&CRI); } + void unpoisonDynamicAllocasBeforeInst(Instruction *InstBefore, Value *SavedStack) { IRBuilder<> IRB(InstBefore); @@ -724,7 +787,14 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { /// \brief Collect Alloca instructions we want (and can) handle. void visitAllocaInst(AllocaInst &AI) { if (!ASan.isInterestingAlloca(AI)) { - if (AI.isStaticAlloca()) NonInstrumentedStaticAllocaVec.insert(&AI); + if (AI.isStaticAlloca()) { + // Skip over allocas that are present *before* the first instrumented + // alloca, we don't want to move those around. + if (AllocaVec.empty()) + return; + + StaticAllocasToMoveUp.push_back(&AI); + } return; } @@ -761,7 +831,10 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { return; bool DoPoison = (ID == Intrinsic::lifetime_end); AllocaPoisonCall APC = {&II, AI, SizeValue, DoPoison}; - AllocaPoisonCallVec.push_back(APC); + if (AI->isStaticAlloca()) + StaticAllocaPoisonCallVec.push_back(APC); + else if (ClInstrumentDynamicAllocas) + DynamicAllocaPoisonCallVec.push_back(APC); } void visitCallSite(CallSite CS) { @@ -785,12 +858,21 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { /// Finds alloca where the value comes from. AllocaInst *findAllocaForValue(Value *V); - void poisonRedZones(ArrayRef<uint8_t> ShadowBytes, IRBuilder<> &IRB, - Value *ShadowBase, bool DoPoison); + + // Copies bytes from ShadowBytes into shadow memory for indexes where + // ShadowMask is not zero. 
If ShadowMask[i] is zero, we assume that + // ShadowBytes[i] is constantly zero and doesn't need to be overwritten. + void copyToShadow(ArrayRef<uint8_t> ShadowMask, ArrayRef<uint8_t> ShadowBytes, + IRBuilder<> &IRB, Value *ShadowBase); + void copyToShadow(ArrayRef<uint8_t> ShadowMask, ArrayRef<uint8_t> ShadowBytes, + size_t Begin, size_t End, IRBuilder<> &IRB, + Value *ShadowBase); + void copyToShadowInline(ArrayRef<uint8_t> ShadowMask, + ArrayRef<uint8_t> ShadowBytes, size_t Begin, + size_t End, IRBuilder<> &IRB, Value *ShadowBase); + void poisonAlloca(Value *V, uint64_t Size, IRBuilder<> &IRB, bool DoPoison); - void SetShadowToStackAfterReturnInlined(IRBuilder<> &IRB, Value *ShadowBase, - int Size); Value *createAllocaForLayout(IRBuilder<> &IRB, const ASanStackFrameLayout &L, bool Dynamic); PHINode *createPHI(IRBuilder<> &IRB, Value *Cond, Value *ValueIfTrue, @@ -885,10 +967,15 @@ Value *AddressSanitizer::memToShadow(Value *Shadow, IRBuilder<> &IRB) { Shadow = IRB.CreateLShr(Shadow, Mapping.Scale); if (Mapping.Offset == 0) return Shadow; // (Shadow >> scale) | offset + Value *ShadowBase; + if (LocalDynamicShadow) + ShadowBase = LocalDynamicShadow; + else + ShadowBase = ConstantInt::get(IntptrTy, Mapping.Offset); if (Mapping.OrShadowOffset) - return IRB.CreateOr(Shadow, ConstantInt::get(IntptrTy, Mapping.Offset)); + return IRB.CreateOr(Shadow, ShadowBase); else - return IRB.CreateAdd(Shadow, ConstantInt::get(IntptrTy, Mapping.Offset)); + return IRB.CreateAdd(Shadow, ShadowBase); } // Instrument memset/memmove/memcpy @@ -911,7 +998,7 @@ void AddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { } /// Check if we want (and can) handle this alloca. -bool AddressSanitizer::isInterestingAlloca(AllocaInst &AI) { +bool AddressSanitizer::isInterestingAlloca(const AllocaInst &AI) { auto PreviouslySeenAllocaInfo = ProcessedAllocas.find(&AI); if (PreviouslySeenAllocaInfo != ProcessedAllocas.end()) @@ -920,27 +1007,32 @@ bool AddressSanitizer::isInterestingAlloca(AllocaInst &AI) { bool IsInteresting = (AI.getAllocatedType()->isSized() && // alloca() may be called with 0 size, ignore it. - ((!AI.isStaticAlloca()) || getAllocaSizeInBytes(&AI) > 0) && + ((!AI.isStaticAlloca()) || getAllocaSizeInBytes(AI) > 0) && // We are only interested in allocas not promotable to registers. // Promotable allocas are common under -O0. (!ClSkipPromotableAllocas || !isAllocaPromotable(&AI)) && // inalloca allocas are not treated as static, and we don't want // dynamic alloca instrumentation for them as well. - !AI.isUsedWithInAlloca()); + !AI.isUsedWithInAlloca() && + // swifterror allocas are register promoted by ISel + !AI.isSwiftError()); ProcessedAllocas[&AI] = IsInteresting; return IsInteresting; } -/// If I is an interesting memory access, return the PointerOperand -/// and set IsWrite/Alignment. Otherwise return nullptr. Value *AddressSanitizer::isInterestingMemoryAccess(Instruction *I, bool *IsWrite, uint64_t *TypeSize, - unsigned *Alignment) { + unsigned *Alignment, + Value **MaybeMask) { // Skip memory accesses inserted by another instrumentation. if (I->getMetadata("nosanitize")) return nullptr; + // Do not instrument the load fetching the dynamic shadow address. 
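With the dynamic-shadow support above, memToShadow combines the shifted address with whichever base applies: the constant Mapping.Offset, or the LocalDynamicShadow value loaded at function entry. A rough standalone model of that selection (plain C++; the field values are illustrative and std::optional stands in for the nullable LocalDynamicShadow member):

#include <cstdint>
#include <cstdio>
#include <optional>

struct ShadowMapping {
  int Scale = 3;
  uint64_t Offset = 1ULL << 44;   // example constant offset
  bool OrShadowOffset = true;     // only legal for power-of-two constant offsets
};

// Models memToShadow: shift the address, pick the shadow base, then OR or ADD.
uint64_t memToShadow(uint64_t Addr, const ShadowMapping &M,
                     std::optional<uint64_t> LocalDynamicShadow) {
  uint64_t Shadow = Addr >> M.Scale;
  if (M.Offset == 0)
    return Shadow;
  uint64_t ShadowBase = LocalDynamicShadow ? *LocalDynamicShadow : M.Offset;
  return M.OrShadowOffset ? (Shadow | ShadowBase) : (Shadow + ShadowBase);
}

int main() {
  ShadowMapping M;                 // constant offset, OR-able
  ShadowMapping MDyn;
  MDyn.OrShadowOffset = false;     // getShadowMapping forces this off for the sentinel
  printf("constant base: 0x%llx\n",
         (unsigned long long)memToShadow(0x1000, M, std::nullopt));
  printf("dynamic base:  0x%llx\n",
         (unsigned long long)memToShadow(0x1000, MDyn, 0x200000000000ull));
  return 0;
}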
+ if (LocalDynamicShadow == I) + return nullptr; + Value *PtrOperand = nullptr; const DataLayout &DL = I->getModule()->getDataLayout(); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { @@ -967,14 +1059,50 @@ Value *AddressSanitizer::isInterestingMemoryAccess(Instruction *I, *TypeSize = DL.getTypeStoreSizeInBits(XCHG->getCompareOperand()->getType()); *Alignment = 0; PtrOperand = XCHG->getPointerOperand(); + } else if (auto CI = dyn_cast<CallInst>(I)) { + auto *F = dyn_cast<Function>(CI->getCalledValue()); + if (F && (F->getName().startswith("llvm.masked.load.") || + F->getName().startswith("llvm.masked.store."))) { + unsigned OpOffset = 0; + if (F->getName().startswith("llvm.masked.store.")) { + if (!ClInstrumentWrites) + return nullptr; + // Masked store has an initial operand for the value. + OpOffset = 1; + *IsWrite = true; + } else { + if (!ClInstrumentReads) + return nullptr; + *IsWrite = false; + } + + auto BasePtr = CI->getOperand(0 + OpOffset); + auto Ty = cast<PointerType>(BasePtr->getType())->getElementType(); + *TypeSize = DL.getTypeStoreSizeInBits(Ty); + if (auto AlignmentConstant = + dyn_cast<ConstantInt>(CI->getOperand(1 + OpOffset))) + *Alignment = (unsigned)AlignmentConstant->getZExtValue(); + else + *Alignment = 1; // No alignment guarantees. We probably got Undef + if (MaybeMask) + *MaybeMask = CI->getOperand(2 + OpOffset); + PtrOperand = BasePtr; + } } - // Do not instrument acesses from different address spaces; we cannot deal - // with them. if (PtrOperand) { + // Do not instrument acesses from different address spaces; we cannot deal + // with them. Type *PtrTy = cast<PointerType>(PtrOperand->getType()->getScalarType()); if (PtrTy->getPointerAddressSpace() != 0) return nullptr; + + // Ignore swifterror addresses. + // swifterror memory addresses are mem2reg promoted by instruction + // selection. As such they cannot have regular uses like an instrumentation + // function and it makes no sense to track them as memory. + if (PtrOperand->isSwiftError()) + return nullptr; } // Treat memory accesses to promotable allocas as non-interesting since they @@ -1025,13 +1153,71 @@ void AddressSanitizer::instrumentPointerComparisonOrSubtraction( IRB.CreateCall(F, Param); } +static void doInstrumentAddress(AddressSanitizer *Pass, Instruction *I, + Instruction *InsertBefore, Value *Addr, + unsigned Alignment, unsigned Granularity, + uint32_t TypeSize, bool IsWrite, + Value *SizeArgument, bool UseCalls, + uint32_t Exp) { + // Instrument a 1-, 2-, 4-, 8-, or 16- byte access with one check + // if the data is properly aligned. 
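The new CallInst branch in isInterestingMemoryAccess above recognises llvm.masked.load / llvm.masked.store and compensates for the store's extra leading value operand with OpOffset. A tiny standalone model of that operand-index bookkeeping (the intrinsic name suffixes in the test are illustrative):

#include <cassert>
#include <cstddef>
#include <string>

// Models the operand layout the pass relies on: llvm.masked.store carries the
// stored value first, so pointer, alignment and mask all sit one slot later
// than in llvm.masked.load.
struct MaskedAccess {
  size_t PtrIdx, AlignIdx, MaskIdx;
  bool IsWrite;
};

MaskedAccess classify(const std::string &Callee) {
  // OpOffset mirrors the pass: 1 for stores (skip the value), 0 for loads.
  bool IsStore = Callee.rfind("llvm.masked.store.", 0) == 0;
  size_t OpOffset = IsStore ? 1 : 0;
  return {0 + OpOffset, 1 + OpOffset, 2 + OpOffset, IsStore};
}

int main() {
  assert(classify("llvm.masked.load.v4i32.p0v4i32").PtrIdx == 0);
  assert(classify("llvm.masked.store.v4i32.p0v4i32").MaskIdx == 3);
  return 0;
}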
+ if ((TypeSize == 8 || TypeSize == 16 || TypeSize == 32 || TypeSize == 64 || + TypeSize == 128) && + (Alignment >= Granularity || Alignment == 0 || Alignment >= TypeSize / 8)) + return Pass->instrumentAddress(I, InsertBefore, Addr, TypeSize, IsWrite, + nullptr, UseCalls, Exp); + Pass->instrumentUnusualSizeOrAlignment(I, InsertBefore, Addr, TypeSize, + IsWrite, nullptr, UseCalls, Exp); +} + +static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass, + const DataLayout &DL, Type *IntptrTy, + Value *Mask, Instruction *I, + Value *Addr, unsigned Alignment, + unsigned Granularity, uint32_t TypeSize, + bool IsWrite, Value *SizeArgument, + bool UseCalls, uint32_t Exp) { + auto *VTy = cast<PointerType>(Addr->getType())->getElementType(); + uint64_t ElemTypeSize = DL.getTypeStoreSizeInBits(VTy->getScalarType()); + unsigned Num = VTy->getVectorNumElements(); + auto Zero = ConstantInt::get(IntptrTy, 0); + for (unsigned Idx = 0; Idx < Num; ++Idx) { + Value *InstrumentedAddress = nullptr; + Instruction *InsertBefore = I; + if (auto *Vector = dyn_cast<ConstantVector>(Mask)) { + // dyn_cast as we might get UndefValue + if (auto *Masked = dyn_cast<ConstantInt>(Vector->getOperand(Idx))) { + if (Masked->isNullValue()) + // Mask is constant false, so no instrumentation needed. + continue; + // If we have a true or undef value, fall through to doInstrumentAddress + // with InsertBefore == I + } + } else { + IRBuilder<> IRB(I); + Value *MaskElem = IRB.CreateExtractElement(Mask, Idx); + TerminatorInst *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false); + InsertBefore = ThenTerm; + } + + IRBuilder<> IRB(InsertBefore); + InstrumentedAddress = + IRB.CreateGEP(Addr, {Zero, ConstantInt::get(IntptrTy, Idx)}); + doInstrumentAddress(Pass, I, InsertBefore, InstrumentedAddress, Alignment, + Granularity, ElemTypeSize, IsWrite, SizeArgument, + UseCalls, Exp); + } +} + void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, Instruction *I, bool UseCalls, const DataLayout &DL) { bool IsWrite = false; unsigned Alignment = 0; uint64_t TypeSize = 0; - Value *Addr = isInterestingMemoryAccess(I, &IsWrite, &TypeSize, &Alignment); + Value *MaybeMask = nullptr; + Value *Addr = + isInterestingMemoryAccess(I, &IsWrite, &TypeSize, &Alignment, &MaybeMask); assert(Addr); // Optimization experiments. @@ -1073,15 +1259,14 @@ void AddressSanitizer::instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, NumInstrumentedReads++; unsigned Granularity = 1 << Mapping.Scale; - // Instrument a 1-, 2-, 4-, 8-, or 16- byte access with one check - // if the data is properly aligned. - if ((TypeSize == 8 || TypeSize == 16 || TypeSize == 32 || TypeSize == 64 || - TypeSize == 128) && - (Alignment >= Granularity || Alignment == 0 || Alignment >= TypeSize / 8)) - return instrumentAddress(I, I, Addr, TypeSize, IsWrite, nullptr, UseCalls, - Exp); - instrumentUnusualSizeOrAlignment(I, Addr, TypeSize, IsWrite, nullptr, - UseCalls, Exp); + if (MaybeMask) { + instrumentMaskedLoadOrStore(this, DL, IntptrTy, MaybeMask, I, Addr, + Alignment, Granularity, TypeSize, IsWrite, + nullptr, UseCalls, Exp); + } else { + doInstrumentAddress(this, I, I, Addr, Alignment, Granularity, TypeSize, + IsWrite, nullptr, UseCalls, Exp); + } } Instruction *AddressSanitizer::generateCrashCode(Instruction *InsertBefore, @@ -1196,9 +1381,9 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, // and the last bytes. We call __asan_report_*_n(addr, real_size) to be able // to report the actual access size. 
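As the comment above says, an access whose size is not one of the five handled power-of-two widths, or that may be under-aligned, goes through instrumentUnusualSizeOrAlignment: either an out-of-line call (the UseCalls path) or inline checks on just the first and last byte of the region so the real size can still be reported. A small standalone sketch of that first/last-byte arithmetic (checkByte is a placeholder for the shadow check instrumentAddress emits):

#include <cstdint>
#include <cstdio>

// Placeholder for the single-byte shadow check that instrumentAddress() emits.
static void checkByte(uint64_t Addr) {
  printf("check byte at 0x%llx\n", (unsigned long long)Addr);
}

// Models the inline branch of instrumentUnusualSizeOrAlignment: validate both
// ends of the accessed region; the reported size is TypeSize / 8 bytes.
static void checkUnusualAccess(uint64_t Addr, uint32_t TypeSizeInBits) {
  uint64_t SizeInBytes = TypeSizeInBits / 8;
  checkByte(Addr);                    // first byte
  checkByte(Addr + SizeInBytes - 1);  // last byte
}

int main() {
  checkUnusualAccess(0x1000, 10 * 8); // e.g. a 10-byte access
  return 0;
}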
void AddressSanitizer::instrumentUnusualSizeOrAlignment( - Instruction *I, Value *Addr, uint32_t TypeSize, bool IsWrite, - Value *SizeArgument, bool UseCalls, uint32_t Exp) { - IRBuilder<> IRB(I); + Instruction *I, Instruction *InsertBefore, Value *Addr, uint32_t TypeSize, + bool IsWrite, Value *SizeArgument, bool UseCalls, uint32_t Exp) { + IRBuilder<> IRB(InsertBefore); Value *Size = ConstantInt::get(IntptrTy, TypeSize / 8); Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); if (UseCalls) { @@ -1212,8 +1397,8 @@ void AddressSanitizer::instrumentUnusualSizeOrAlignment( Value *LastByte = IRB.CreateIntToPtr( IRB.CreateAdd(AddrLong, ConstantInt::get(IntptrTy, TypeSize / 8 - 1)), Addr->getType()); - instrumentAddress(I, I, Addr, 8, IsWrite, Size, false, Exp); - instrumentAddress(I, I, LastByte, 8, IsWrite, Size, false, Exp); + instrumentAddress(I, InsertBefore, Addr, 8, IsWrite, Size, false, Exp); + instrumentAddress(I, InsertBefore, LastByte, 8, IsWrite, Size, false, Exp); } } @@ -1361,6 +1546,16 @@ bool AddressSanitizerModule::ShouldUseMachOGlobalsSection() const { return false; } +StringRef AddressSanitizerModule::getGlobalMetadataSection() const { + switch (TargetTriple.getObjectFormat()) { + case Triple::COFF: return ".ASAN$GL"; + case Triple::ELF: return "asan_globals"; + case Triple::MachO: return "__DATA,__asan_globals,regular"; + default: break; + } + llvm_unreachable("unsupported object format"); +} + void AddressSanitizerModule::initializeCallbacks(Module &M) { IRBuilder<> IRB(*C); @@ -1383,17 +1578,173 @@ void AddressSanitizerModule::initializeCallbacks(Module &M) { // Declare the functions that find globals in a shared object and then invoke // the (un)register function on them. - AsanRegisterImageGlobals = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanRegisterImageGlobalsName, - IRB.getVoidTy(), IntptrTy, nullptr)); + AsanRegisterImageGlobals = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + kAsanRegisterImageGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); AsanRegisterImageGlobals->setLinkage(Function::ExternalLinkage); - AsanUnregisterImageGlobals = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(kAsanUnregisterImageGlobalsName, - IRB.getVoidTy(), IntptrTy, nullptr)); + AsanUnregisterImageGlobals = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + kAsanUnregisterImageGlobalsName, IRB.getVoidTy(), IntptrTy, nullptr)); AsanUnregisterImageGlobals->setLinkage(Function::ExternalLinkage); } +// Put the metadata and the instrumented global in the same group. This ensures +// that the metadata is discarded if the instrumented global is discarded. +void AddressSanitizerModule::SetComdatForGlobalMetadata( + GlobalVariable *G, GlobalVariable *Metadata) { + Module &M = *G->getParent(); + Comdat *C = G->getComdat(); + if (!C) { + if (!G->hasName()) { + // If G is unnamed, it must be internal. Give it an artificial name + // so we can put it in a comdat. + assert(G->hasLocalLinkage()); + G->setName(Twine(kAsanGenPrefix) + "_anon_global"); + } + C = M.getOrInsertComdat(G->getName()); + // Make this IMAGE_COMDAT_SELECT_NODUPLICATES on COFF. + if (TargetTriple.isOSBinFormatCOFF()) + C->setSelectionKind(Comdat::NoDuplicates); + G->setComdat(C); + } + + assert(G->hasComdat()); + Metadata->setComdat(G->getComdat()); +} + +// Create a separate metadata global and put it in the appropriate ASan +// global registration section. 
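getGlobalMetadataSection above selects where the per-global descriptor lives based on the object format. The same mapping, as a trivial standalone function for reference (section strings copied from the hunk):

#include <cstdio>
#include <stdexcept>
#include <string>

enum class ObjectFormat { COFF, ELF, MachO };

std::string globalMetadataSection(ObjectFormat F) {
  switch (F) {
  case ObjectFormat::COFF:  return ".ASAN$GL";
  case ObjectFormat::ELF:   return "asan_globals";
  case ObjectFormat::MachO: return "__DATA,__asan_globals,regular";
  }
  throw std::logic_error("unsupported object format");
}

int main() {
  printf("%s\n", globalMetadataSection(ObjectFormat::ELF).c_str());
  return 0;
}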
+GlobalVariable * +AddressSanitizerModule::CreateMetadataGlobal(Module &M, Constant *Initializer, + StringRef OriginalName) { + GlobalVariable *Metadata = + new GlobalVariable(M, Initializer->getType(), false, + GlobalVariable::InternalLinkage, Initializer, + Twine("__asan_global_") + + GlobalValue::getRealLinkageName(OriginalName)); + Metadata->setSection(getGlobalMetadataSection()); + return Metadata; +} + +IRBuilder<> AddressSanitizerModule::CreateAsanModuleDtor(Module &M) { + Function *AsanDtorFunction = + Function::Create(FunctionType::get(Type::getVoidTy(*C), false), + GlobalValue::InternalLinkage, kAsanModuleDtorName, &M); + BasicBlock *AsanDtorBB = BasicBlock::Create(*C, "", AsanDtorFunction); + appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority); + + return IRBuilder<>(ReturnInst::Create(*C, AsanDtorBB)); +} + +void AddressSanitizerModule::InstrumentGlobalsCOFF( + IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, + ArrayRef<Constant *> MetadataInitializers) { + assert(ExtendedGlobals.size() == MetadataInitializers.size()); + auto &DL = M.getDataLayout(); + + for (size_t i = 0; i < ExtendedGlobals.size(); i++) { + Constant *Initializer = MetadataInitializers[i]; + GlobalVariable *G = ExtendedGlobals[i]; + GlobalVariable *Metadata = + CreateMetadataGlobal(M, Initializer, G->getName()); + + // The MSVC linker always inserts padding when linking incrementally. We + // cope with that by aligning each struct to its size, which must be a power + // of two. + unsigned SizeOfGlobalStruct = DL.getTypeAllocSize(Initializer->getType()); + assert(isPowerOf2_32(SizeOfGlobalStruct) && + "global metadata will not be padded appropriately"); + Metadata->setAlignment(SizeOfGlobalStruct); + + SetComdatForGlobalMetadata(G, Metadata); + } +} + +void AddressSanitizerModule::InstrumentGlobalsMachO( + IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, + ArrayRef<Constant *> MetadataInitializers) { + assert(ExtendedGlobals.size() == MetadataInitializers.size()); + + // On recent Mach-O platforms, use a structure which binds the liveness of + // the global variable to the metadata struct. Keep the list of "Liveness" GV + // created to be added to llvm.compiler.used + StructType *LivenessTy = StructType::get(IntptrTy, IntptrTy, nullptr); + SmallVector<GlobalValue *, 16> LivenessGlobals(ExtendedGlobals.size()); + + for (size_t i = 0; i < ExtendedGlobals.size(); i++) { + Constant *Initializer = MetadataInitializers[i]; + GlobalVariable *G = ExtendedGlobals[i]; + GlobalVariable *Metadata = + CreateMetadataGlobal(M, Initializer, G->getName()); + + // On recent Mach-O platforms, we emit the global metadata in a way that + // allows the linker to properly strip dead globals. + auto LivenessBinder = ConstantStruct::get( + LivenessTy, Initializer->getAggregateElement(0u), + ConstantExpr::getPointerCast(Metadata, IntptrTy), nullptr); + GlobalVariable *Liveness = new GlobalVariable( + M, LivenessTy, false, GlobalVariable::InternalLinkage, LivenessBinder, + Twine("__asan_binder_") + G->getName()); + Liveness->setSection("__DATA,__asan_liveness,regular,live_support"); + LivenessGlobals[i] = Liveness; + } + + // Update llvm.compiler.used, adding the new liveness globals. This is + // needed so that during LTO these variables stay alive. The alternative + // would be to have the linker handling the LTO symbols, but libLTO + // current API does not expose access to the section for each symbol. 
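In InstrumentGlobalsCOFF above, each metadata global is aligned to the size of the descriptor struct, so any padding the incrementally-linking MSVC linker inserts between .ASAN$GL entries stays a multiple of the entry size; the assert guards the power-of-two requirement behind that trick. A quick standalone illustration of the invariant (the eight pointer-sized fields mirror the descriptor layout described later in the file):

#include <cassert>
#include <cstddef>
#include <cstdint>

static bool isPowerOf2(uint64_t V) { return V && (V & (V - 1)) == 0; }

int main() {
  // Eight pointer-sized members, mirroring the global descriptor struct.
  struct GlobalDescriptor { uintptr_t Fields[8]; };
  size_t SizeOfGlobalStruct = sizeof(GlobalDescriptor);

  // The pass asserts this before using the struct size as the alignment.
  assert(isPowerOf2(SizeOfGlobalStruct) &&
         "global metadata will not be padded appropriately");
  size_t Alignment = SizeOfGlobalStruct;   // Metadata->setAlignment(...)
  (void)Alignment;
  return 0;
}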
+ if (!LivenessGlobals.empty()) + appendToCompilerUsed(M, LivenessGlobals); + + // RegisteredFlag serves two purposes. First, we can pass it to dladdr() + // to look up the loaded image that contains it. Second, we can store in it + // whether registration has already occurred, to prevent duplicate + // registration. + // + // common linkage ensures that there is only one global per shared library. + GlobalVariable *RegisteredFlag = new GlobalVariable( + M, IntptrTy, false, GlobalVariable::CommonLinkage, + ConstantInt::get(IntptrTy, 0), kAsanGlobalsRegisteredFlagName); + RegisteredFlag->setVisibility(GlobalVariable::HiddenVisibility); + + IRB.CreateCall(AsanRegisterImageGlobals, + {IRB.CreatePointerCast(RegisteredFlag, IntptrTy)}); + + // We also need to unregister globals at the end, e.g., when a shared library + // gets closed. + IRBuilder<> IRB_Dtor = CreateAsanModuleDtor(M); + IRB_Dtor.CreateCall(AsanUnregisterImageGlobals, + {IRB.CreatePointerCast(RegisteredFlag, IntptrTy)}); +} + +void AddressSanitizerModule::InstrumentGlobalsWithMetadataArray( + IRBuilder<> &IRB, Module &M, ArrayRef<GlobalVariable *> ExtendedGlobals, + ArrayRef<Constant *> MetadataInitializers) { + assert(ExtendedGlobals.size() == MetadataInitializers.size()); + unsigned N = ExtendedGlobals.size(); + assert(N > 0); + + // On platforms that don't have a custom metadata section, we emit an array + // of global metadata structures. + ArrayType *ArrayOfGlobalStructTy = + ArrayType::get(MetadataInitializers[0]->getType(), N); + auto AllGlobals = new GlobalVariable( + M, ArrayOfGlobalStructTy, false, GlobalVariable::InternalLinkage, + ConstantArray::get(ArrayOfGlobalStructTy, MetadataInitializers), ""); + + IRB.CreateCall(AsanRegisterGlobals, + {IRB.CreatePointerCast(AllGlobals, IntptrTy), + ConstantInt::get(IntptrTy, N)}); + + // We also need to unregister globals at the end, e.g., when a shared library + // gets closed. + IRBuilder<> IRB_Dtor = CreateAsanModuleDtor(M); + IRB_Dtor.CreateCall(AsanUnregisterGlobals, + {IRB.CreatePointerCast(AllGlobals, IntptrTy), + ConstantInt::get(IntptrTy, N)}); +} + // This function replaces all global variables with new variables that have // trailing redzones. It also creates a function that poisons // redzones and inserts this function into llvm.global_ctors. 
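InstrumentGlobalsWithMetadataArray above is the fallback for formats without a dedicated metadata section: all descriptors go into one internal array that is handed to the runtime with its element count in the module constructor and unregistered again in the destructor. A standalone sketch of that protocol, with stub functions standing in for the AsanRegisterGlobals / AsanUnregisterGlobals callees and a two-field descriptor standing in for the real eight-field struct:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Two of the descriptor fields named in the comment further down; the real
// struct has eight pointer-sized members.
struct GlobalDescriptor { uintptr_t Beg, Size; };

// Stubs standing in for the runtime functions the pass calls.
static void registerGlobals(const GlobalDescriptor *Start, size_t N) {
  (void)Start;
  printf("register %zu descriptors\n", N);
}
static void unregisterGlobals(const GlobalDescriptor *Start, size_t N) {
  (void)Start;
  printf("unregister %zu descriptors\n", N);
}

int main() {
  // One contiguous array, as InstrumentGlobalsWithMetadataArray emits.
  std::vector<GlobalDescriptor> AllGlobals = {{0x1000, 64}, {0x2000, 128}};
  registerGlobals(AllGlobals.data(), AllGlobals.size());    // module ctor
  unregisterGlobals(AllGlobals.data(), AllGlobals.size());  // module dtor
  return 0;
}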
@@ -1409,6 +1760,8 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { size_t n = GlobalsToChange.size(); if (n == 0) return false; + auto &DL = M.getDataLayout(); + // A global is described by a structure // size_t beg; // size_t size; @@ -1422,6 +1775,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { StructType *GlobalStructTy = StructType::get(IntptrTy, IntptrTy, IntptrTy, IntptrTy, IntptrTy, IntptrTy, IntptrTy, IntptrTy, nullptr); + SmallVector<GlobalVariable *, 16> NewGlobals(n); SmallVector<Constant *, 16> Initializers(n); bool HasDynamicallyInitializedGlobals = false; @@ -1431,7 +1785,6 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { GlobalVariable *ModuleName = createPrivateGlobalForString( M, M.getModuleIdentifier(), /*AllowMerging*/ false); - auto &DL = M.getDataLayout(); for (size_t i = 0; i < n; i++) { static const uint64_t kMaxGlobalRedzone = 1 << 18; GlobalVariable *G = GlobalsToChange[i]; @@ -1472,6 +1825,21 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { NewGlobal->copyAttributesFrom(G); NewGlobal->setAlignment(MinRZ); + // Move null-terminated C strings to "__asan_cstring" section on Darwin. + if (TargetTriple.isOSBinFormatMachO() && !G->hasSection() && + G->isConstant()) { + auto Seq = dyn_cast<ConstantDataSequential>(G->getInitializer()); + if (Seq && Seq->isCString()) + NewGlobal->setSection("__TEXT,__asan_cstring,regular"); + } + + // Transfer the debug info. The payload starts at offset zero so we can + // copy the debug info over as is. + SmallVector<DIGlobalVariableExpression *, 1> GVs; + G->getDebugInfo(GVs); + for (auto *GV : GVs) + NewGlobal->addDebugInfo(GV); + Value *Indices2[2]; Indices2[0] = IRB.getInt32(0); Indices2[1] = IRB.getInt32(0); @@ -1480,6 +1848,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { ConstantExpr::getGetElementPtr(NewTy, NewGlobal, Indices2, true)); NewGlobal->takeName(G); G->eraseFromParent(); + NewGlobals[i] = NewGlobal; Constant *SourceLoc; if (!MD.SourceLoc.empty()) { @@ -1492,7 +1861,8 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { Constant *ODRIndicator = ConstantExpr::getNullValue(IRB.getInt8PtrTy()); GlobalValue *InstrumentedGlobal = NewGlobal; - bool CanUsePrivateAliases = TargetTriple.isOSBinFormatELF(); + bool CanUsePrivateAliases = + TargetTriple.isOSBinFormatELF() || TargetTriple.isOSBinFormatMachO(); if (CanUsePrivateAliases && ClUsePrivateAliasForGlobals) { // Create local alias for NewGlobal to avoid crash on ODR between // instrumented and non-instrumented libraries. @@ -1515,7 +1885,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { InstrumentedGlobal = GA; } - Initializers[i] = ConstantStruct::get( + Constant *Initializer = ConstantStruct::get( GlobalStructTy, ConstantExpr::getPointerCast(InstrumentedGlobal, IntptrTy), ConstantInt::get(IntptrTy, SizeInBytes), @@ -1528,88 +1898,22 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M) { if (ClInitializers && MD.IsDynInit) HasDynamicallyInitializedGlobals = true; DEBUG(dbgs() << "NEW GLOBAL: " << *NewGlobal << "\n"); - } + Initializers[i] = Initializer; + } - GlobalVariable *AllGlobals = nullptr; - GlobalVariable *RegisteredFlag = nullptr; - - // On recent Mach-O platforms, we emit the global metadata in a way that - // allows the linker to properly strip dead globals. 
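One detail in the hunk above: on Mach-O, constant globals with no explicit section whose initializer is a C string are moved into __TEXT,__asan_cstring. A rough standalone approximation of that decision (looksLikeCString only approximates ConstantDataSequential::isCString, and the boolean parameters stand in for the IR queries):

#include <cstdio>
#include <string>
#include <vector>

// Rough approximation of the isCString() predicate: a byte array that ends
// in exactly one terminating NUL and has no interior NULs.
static bool looksLikeCString(const std::vector<char> &Bytes) {
  if (Bytes.empty() || Bytes.back() != '\0')
    return false;
  for (size_t i = 0; i + 1 < Bytes.size(); ++i)
    if (Bytes[i] == '\0')
      return false;
  return true;
}

// Mirrors the condition in the hunk: only constant, section-less globals on
// Mach-O whose initializer is a C string are moved.
static std::string pickSection(bool IsMachO, bool HasSection, bool IsConstant,
                               const std::vector<char> &Init) {
  if (IsMachO && !HasSection && IsConstant && looksLikeCString(Init))
    return "__TEXT,__asan_cstring,regular";
  return "";   // keep the default section
}

int main() {
  std::vector<char> Str = {'h', 'i', '\0'};
  printf("%s\n", pickSection(true, false, true, Str).c_str());
  return 0;
}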
- if (ShouldUseMachOGlobalsSection()) { - // RegisteredFlag serves two purposes. First, we can pass it to dladdr() - // to look up the loaded image that contains it. Second, we can store in it - // whether registration has already occurred, to prevent duplicate - // registration. - // - // Common linkage allows us to coalesce needles defined in each object - // file so that there's only one per shared library. - RegisteredFlag = new GlobalVariable( - M, IntptrTy, false, GlobalVariable::CommonLinkage, - ConstantInt::get(IntptrTy, 0), kAsanGlobalsRegisteredFlagName); - - // We also emit a structure which binds the liveness of the global - // variable to the metadata struct. - StructType *LivenessTy = StructType::get(IntptrTy, IntptrTy, nullptr); - - for (size_t i = 0; i < n; i++) { - GlobalVariable *Metadata = new GlobalVariable( - M, GlobalStructTy, false, GlobalVariable::InternalLinkage, - Initializers[i], ""); - Metadata->setSection("__DATA,__asan_globals,regular"); - Metadata->setAlignment(1); // don't leave padding in between - - auto LivenessBinder = ConstantStruct::get(LivenessTy, - Initializers[i]->getAggregateElement(0u), - ConstantExpr::getPointerCast(Metadata, IntptrTy), - nullptr); - GlobalVariable *Liveness = new GlobalVariable( - M, LivenessTy, false, GlobalVariable::InternalLinkage, - LivenessBinder, ""); - Liveness->setSection("__DATA,__asan_liveness,regular,live_support"); - } + if (TargetTriple.isOSBinFormatCOFF()) { + InstrumentGlobalsCOFF(IRB, M, NewGlobals, Initializers); + } else if (ShouldUseMachOGlobalsSection()) { + InstrumentGlobalsMachO(IRB, M, NewGlobals, Initializers); } else { - // On all other platfoms, we just emit an array of global metadata - // structures. - ArrayType *ArrayOfGlobalStructTy = ArrayType::get(GlobalStructTy, n); - AllGlobals = new GlobalVariable( - M, ArrayOfGlobalStructTy, false, GlobalVariable::InternalLinkage, - ConstantArray::get(ArrayOfGlobalStructTy, Initializers), ""); + InstrumentGlobalsWithMetadataArray(IRB, M, NewGlobals, Initializers); } // Create calls for poisoning before initializers run and unpoisoning after. if (HasDynamicallyInitializedGlobals) createInitializerPoisonCalls(M, ModuleName); - // Create a call to register the globals with the runtime. - if (ShouldUseMachOGlobalsSection()) { - IRB.CreateCall(AsanRegisterImageGlobals, - {IRB.CreatePointerCast(RegisteredFlag, IntptrTy)}); - } else { - IRB.CreateCall(AsanRegisterGlobals, - {IRB.CreatePointerCast(AllGlobals, IntptrTy), - ConstantInt::get(IntptrTy, n)}); - } - - // We also need to unregister globals at the end, e.g., when a shared library - // gets closed. 
- Function *AsanDtorFunction = - Function::Create(FunctionType::get(Type::getVoidTy(*C), false), - GlobalValue::InternalLinkage, kAsanModuleDtorName, &M); - BasicBlock *AsanDtorBB = BasicBlock::Create(*C, "", AsanDtorFunction); - IRBuilder<> IRB_Dtor(ReturnInst::Create(*C, AsanDtorBB)); - - if (ShouldUseMachOGlobalsSection()) { - IRB_Dtor.CreateCall(AsanUnregisterImageGlobals, - {IRB.CreatePointerCast(RegisteredFlag, IntptrTy)}); - } else { - IRB_Dtor.CreateCall(AsanUnregisterGlobals, - {IRB.CreatePointerCast(AllGlobals, IntptrTy), - ConstantInt::get(IntptrTy, n)}); - } - - appendToGlobalDtors(M, AsanDtorFunction, kAsanCtorAndDtorPriority); - DEBUG(dbgs() << M); return true; } @@ -1737,6 +2041,17 @@ bool AddressSanitizer::maybeInsertAsanInitAtFunctionEntry(Function &F) { return false; } +void AddressSanitizer::maybeInsertDynamicShadowAtFunctionEntry(Function &F) { + // Generate code only when dynamic addressing is needed. + if (Mapping.Offset != kDynamicShadowSentinel) + return; + + IRBuilder<> IRB(&F.front().front()); + Value *GlobalDynamicAddress = F.getParent()->getOrInsertGlobal( + kAsanShadowMemoryDynamicAddress, IntptrTy); + LocalDynamicShadow = IRB.CreateLoad(GlobalDynamicAddress); +} + void AddressSanitizer::markEscapedLocalAllocas(Function &F) { // Find the one possible call to llvm.localescape and pre-mark allocas passed // to it as uninteresting. This assumes we haven't started processing allocas @@ -1768,20 +2083,29 @@ void AddressSanitizer::markEscapedLocalAllocas(Function &F) { bool AddressSanitizer::runOnFunction(Function &F) { if (&F == AsanCtorFunction) return false; if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) return false; - DEBUG(dbgs() << "ASAN instrumenting:\n" << F << "\n"); - initializeCallbacks(*F.getParent()); + if (!ClDebugFunc.empty() && ClDebugFunc == F.getName()) return false; + if (F.getName().startswith("__asan_")) return false; - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + bool FunctionModified = false; // If needed, insert __asan_init before checking for SanitizeAddress attr. - maybeInsertAsanInitAtFunctionEntry(F); + // This function needs to be called even if the function body is not + // instrumented. + if (maybeInsertAsanInitAtFunctionEntry(F)) + FunctionModified = true; + + // Leave if the function doesn't need instrumentation. + if (!F.hasFnAttribute(Attribute::SanitizeAddress)) return FunctionModified; - if (!F.hasFnAttribute(Attribute::SanitizeAddress)) return false; + DEBUG(dbgs() << "ASAN instrumenting:\n" << F << "\n"); - if (!ClDebugFunc.empty() && ClDebugFunc != F.getName()) return false; + initializeCallbacks(*F.getParent()); + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); FunctionStateRAII CleanupObj(this); + maybeInsertDynamicShadowAtFunctionEntry(F); + // We can't instrument allocas used with llvm.localescape. Only static allocas // can be passed to that intrinsic. markEscapedLocalAllocas(F); @@ -1807,11 +2131,20 @@ bool AddressSanitizer::runOnFunction(Function &F) { int NumInsnsPerBB = 0; for (auto &Inst : BB) { if (LooksLikeCodeInBug11395(&Inst)) return false; + Value *MaybeMask = nullptr; if (Value *Addr = isInterestingMemoryAccess(&Inst, &IsWrite, &TypeSize, - &Alignment)) { + &Alignment, &MaybeMask)) { if (ClOpt && ClOptSameTemp) { - if (!TempsToInstrument.insert(Addr).second) - continue; // We've seen this temp in the current BB. + // If we have a mask, skip instrumentation if we've already + // instrumented the full object. 
But don't add to TempsToInstrument + // because we might get another load/store with a different mask. + if (MaybeMask) { + if (TempsToInstrument.count(Addr)) + continue; // We've seen this (whole) temp in the current BB. + } else { + if (!TempsToInstrument.insert(Addr).second) + continue; // We've seen this temp in the current BB. + } } } else if (ClInvalidPointerPairs && isInterestingPointerComparisonOrSubtraction(&Inst)) { @@ -1874,11 +2207,13 @@ bool AddressSanitizer::runOnFunction(Function &F) { NumInstrumented++; } - bool res = NumInstrumented > 0 || ChangedStack || !NoReturnCalls.empty(); + if (NumInstrumented > 0 || ChangedStack || !NoReturnCalls.empty()) + FunctionModified = true; - DEBUG(dbgs() << "ASAN done instrumenting: " << res << " " << F << "\n"); + DEBUG(dbgs() << "ASAN done instrumenting: " << FunctionModified << " " + << F << "\n"); - return res; + return FunctionModified; } // Workaround for bug 11395: we don't want to instrument stack in functions @@ -1913,6 +2248,15 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) { IntptrTy, IntptrTy, nullptr)); } + for (size_t Val : {0x00, 0xf1, 0xf2, 0xf3, 0xf5, 0xf8}) { + std::ostringstream Name; + Name << kAsanSetShadowPrefix; + Name << std::setw(2) << std::setfill('0') << std::hex << Val; + AsanSetShadowFunc[Val] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + Name.str(), IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); + } + AsanAllocaPoisonFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction( kAsanAllocaPoison, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); AsanAllocasUnpoisonFunc = @@ -1920,31 +2264,93 @@ void FunctionStackPoisoner::initializeCallbacks(Module &M) { kAsanAllocasUnpoison, IRB.getVoidTy(), IntptrTy, IntptrTy, nullptr)); } -void FunctionStackPoisoner::poisonRedZones(ArrayRef<uint8_t> ShadowBytes, - IRBuilder<> &IRB, Value *ShadowBase, - bool DoPoison) { - size_t n = ShadowBytes.size(); - size_t i = 0; - // We need to (un)poison n bytes of stack shadow. Poison as many as we can - // using 64-bit stores (if we are on 64-bit arch), then poison the rest - // with 32-bit stores, then with 16-byte stores, then with 8-byte stores. - for (size_t LargeStoreSizeInBytes = ASan.LongSize / 8; - LargeStoreSizeInBytes != 0; LargeStoreSizeInBytes /= 2) { - for (; i + LargeStoreSizeInBytes - 1 < n; i += LargeStoreSizeInBytes) { - uint64_t Val = 0; - for (size_t j = 0; j < LargeStoreSizeInBytes; j++) { - if (F.getParent()->getDataLayout().isLittleEndian()) - Val |= (uint64_t)ShadowBytes[i + j] << (8 * j); - else - Val = (Val << 8) | ShadowBytes[i + j]; - } - if (!Val) continue; - Value *Ptr = IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i)); - Type *StoreTy = Type::getIntNTy(*C, LargeStoreSizeInBytes * 8); - Value *Poison = ConstantInt::get(StoreTy, DoPoison ? Val : 0); - IRB.CreateStore(Poison, IRB.CreateIntToPtr(Ptr, StoreTy->getPointerTo())); +void FunctionStackPoisoner::copyToShadowInline(ArrayRef<uint8_t> ShadowMask, + ArrayRef<uint8_t> ShadowBytes, + size_t Begin, size_t End, + IRBuilder<> &IRB, + Value *ShadowBase) { + if (Begin >= End) + return; + + const size_t LargestStoreSizeInBytes = + std::min<size_t>(sizeof(uint64_t), ASan.LongSize / 8); + + const bool IsLittleEndian = F.getParent()->getDataLayout().isLittleEndian(); + + // Poison given range in shadow using larges store size with out leading and + // trailing zeros in ShadowMask. Zeros never change, so they need neither + // poisoning nor up-poisoning. Still we don't mind if some of them get into a + // middle of a store. 
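The comment above introduces copyToShadowInline, which widens stores as far as it can, shrinks them when the tail would only rewrite bytes the mask marks as untouched zeros, and packs the shadow bytes according to the target's endianness. A standalone model of that planning loop (plain C++, mirroring the loop that follows; the example mask and byte values are illustrative):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct ShadowStore { size_t Offset; size_t SizeInBytes; uint64_t Value; };

// Pick the widest store that fits, trim it when its trailing bytes are
// mask-zero, and pack the covered shadow bytes into one integer.
std::vector<ShadowStore> planInlineStores(const std::vector<uint8_t> &Mask,
                                          const std::vector<uint8_t> &Bytes,
                                          size_t Begin, size_t End,
                                          size_t LargestStore, bool LittleEndian) {
  std::vector<ShadowStore> Stores;
  for (size_t i = Begin; i < End;) {
    if (!Mask[i]) { ++i; continue; }          // constant-zero byte, leave it alone

    size_t StoreSize = LargestStore;
    while (StoreSize > End - i)               // fit the store into the range
      StoreSize /= 2;
    for (size_t j = StoreSize - 1; j && !Mask[i + j]; --j)
      while (j <= StoreSize / 2)              // trim trailing zero bytes
        StoreSize /= 2;

    uint64_t Val = 0;
    for (size_t j = 0; j < StoreSize; ++j)    // pack bytes, honouring endianness
      Val = LittleEndian ? Val | ((uint64_t)Bytes[i + j] << (8 * j))
                         : (Val << 8) | Bytes[i + j];

    Stores.push_back({i, StoreSize, Val});
    i += StoreSize;
  }
  return Stores;
}

int main() {
  std::vector<uint8_t> Mask  = {1, 1, 1, 1, 0, 0, 1, 1};
  std::vector<uint8_t> Bytes = {0xf1, 0xf1, 0xf2, 0xf2, 0, 0, 0xf3, 0xf3};
  for (const auto &S : planInlineStores(Mask, Bytes, 0, Mask.size(), 8, true))
    printf("store %zu bytes of 0x%llx at +%zu\n", S.SizeInBytes,
           (unsigned long long)S.Value, S.Offset);
  return 0;
}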
+ for (size_t i = Begin; i < End;) { + if (!ShadowMask[i]) { + assert(!ShadowBytes[i]); + ++i; + continue; + } + + size_t StoreSizeInBytes = LargestStoreSizeInBytes; + // Fit store size into the range. + while (StoreSizeInBytes > End - i) + StoreSizeInBytes /= 2; + + // Minimize store size by trimming trailing zeros. + for (size_t j = StoreSizeInBytes - 1; j && !ShadowMask[i + j]; --j) { + while (j <= StoreSizeInBytes / 2) + StoreSizeInBytes /= 2; + } + + uint64_t Val = 0; + for (size_t j = 0; j < StoreSizeInBytes; j++) { + if (IsLittleEndian) + Val |= (uint64_t)ShadowBytes[i + j] << (8 * j); + else + Val = (Val << 8) | ShadowBytes[i + j]; + } + + Value *Ptr = IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i)); + Value *Poison = IRB.getIntN(StoreSizeInBytes * 8, Val); + IRB.CreateAlignedStore( + Poison, IRB.CreateIntToPtr(Ptr, Poison->getType()->getPointerTo()), 1); + + i += StoreSizeInBytes; + } +} + +void FunctionStackPoisoner::copyToShadow(ArrayRef<uint8_t> ShadowMask, + ArrayRef<uint8_t> ShadowBytes, + IRBuilder<> &IRB, Value *ShadowBase) { + copyToShadow(ShadowMask, ShadowBytes, 0, ShadowMask.size(), IRB, ShadowBase); +} + +void FunctionStackPoisoner::copyToShadow(ArrayRef<uint8_t> ShadowMask, + ArrayRef<uint8_t> ShadowBytes, + size_t Begin, size_t End, + IRBuilder<> &IRB, Value *ShadowBase) { + assert(ShadowMask.size() == ShadowBytes.size()); + size_t Done = Begin; + for (size_t i = Begin, j = Begin + 1; i < End; i = j++) { + if (!ShadowMask[i]) { + assert(!ShadowBytes[i]); + continue; + } + uint8_t Val = ShadowBytes[i]; + if (!AsanSetShadowFunc[Val]) + continue; + + // Skip same values. + for (; j < End && ShadowMask[j] && Val == ShadowBytes[j]; ++j) { + } + + if (j - i >= ClMaxInlinePoisoningSize) { + copyToShadowInline(ShadowMask, ShadowBytes, Done, i, IRB, ShadowBase); + IRB.CreateCall(AsanSetShadowFunc[Val], + {IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i)), + ConstantInt::get(IntptrTy, j - i)}); + Done = j; } } + + copyToShadowInline(ShadowMask, ShadowBytes, Done, End, IRB, ShadowBase); } // Fake stack allocator (asan_fake_stack.h) has 11 size classes @@ -1957,26 +2363,6 @@ static int StackMallocSizeClass(uint64_t LocalStackSize) { llvm_unreachable("impossible LocalStackSize"); } -// Set Size bytes starting from ShadowBase to kAsanStackAfterReturnMagic. -// We can not use MemSet intrinsic because it may end up calling the actual -// memset. Size is a multiple of 8. -// Currently this generates 8-byte stores on x86_64; it may be better to -// generate wider stores. -void FunctionStackPoisoner::SetShadowToStackAfterReturnInlined( - IRBuilder<> &IRB, Value *ShadowBase, int Size) { - assert(!(Size % 8)); - - // kAsanStackAfterReturnMagic is 0xf5. 
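The outer copyToShadow above splits the work between inline stores and the AsanSetShadowFunc callbacks registered further down: a long enough run of identical, non-zero-masked shadow bytes becomes one callback call, and everything in between goes through copyToShadowInline. A standalone model of that run splitting (the threshold stands in for ClMaxInlinePoisoningSize and the byte values are examples; the real pass also skips values it has no callback for):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

void planShadowCopy(const std::vector<uint8_t> &Mask,
                    const std::vector<uint8_t> &Bytes, size_t Threshold) {
  size_t Done = 0, End = Mask.size();
  for (size_t i = 0, j = 1; i < End; i = j++) {
    if (!Mask[i])
      continue;                                   // constant-zero byte, skip
    uint8_t Val = Bytes[i];
    for (; j < End && Mask[j] && Bytes[j] == Val; ++j)
      ;                                           // extend the run of equal values
    if (j - i >= Threshold) {
      printf("inline  [%zu, %zu)\n", Done, i);    // copyToShadowInline(Done, i)
      printf("outcall AsanSetShadowFunc[0x%02x](+%zu, %zu)\n", (unsigned)Val, i, j - i);
      Done = j;
    }
  }
  printf("inline  [%zu, %zu)\n", Done, End);      // trailing inline part
}

int main() {
  std::vector<uint8_t> Mask(96, 1), Bytes(96, 0xf8);
  Bytes[0] = 0xf1;                                // short prefix stays inline
  planShadowCopy(Mask, Bytes, /*Threshold=*/64);  // example threshold
  return 0;
}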
- const uint64_t kAsanStackAfterReturnMagic64 = 0xf5f5f5f5f5f5f5f5ULL; - - for (int i = 0; i < Size; i += 8) { - Value *p = IRB.CreateAdd(ShadowBase, ConstantInt::get(IntptrTy, i)); - IRB.CreateStore( - ConstantInt::get(IRB.getInt64Ty(), kAsanStackAfterReturnMagic64), - IRB.CreateIntToPtr(p, IRB.getInt64Ty()->getPointerTo())); - } -} - PHINode *FunctionStackPoisoner::createPHI(IRBuilder<> &IRB, Value *Cond, Value *ValueIfTrue, Instruction *ThenTerm, @@ -2015,37 +2401,39 @@ void FunctionStackPoisoner::createDynamicAllocasInitStorage() { DynamicAllocaLayout->setAlignment(32); } -void FunctionStackPoisoner::poisonStack() { - assert(AllocaVec.size() > 0 || DynamicAllocaVec.size() > 0); +void FunctionStackPoisoner::processDynamicAllocas() { + if (!ClInstrumentDynamicAllocas || DynamicAllocaVec.empty()) { + assert(DynamicAllocaPoisonCallVec.empty()); + return; + } - // Insert poison calls for lifetime intrinsics for alloca. - bool HavePoisonedStaticAllocas = false; - for (const auto &APC : AllocaPoisonCallVec) { + // Insert poison calls for lifetime intrinsics for dynamic allocas. + for (const auto &APC : DynamicAllocaPoisonCallVec) { assert(APC.InsBefore); assert(APC.AI); assert(ASan.isInterestingAlloca(*APC.AI)); - bool IsDynamicAlloca = !(*APC.AI).isStaticAlloca(); - if (!ClInstrumentAllocas && IsDynamicAlloca) - continue; + assert(!APC.AI->isStaticAlloca()); IRBuilder<> IRB(APC.InsBefore); poisonAlloca(APC.AI, APC.Size, IRB, APC.DoPoison); // Dynamic allocas will be unpoisoned unconditionally below in // unpoisonDynamicAllocas. // Flag that we need unpoison static allocas. - HavePoisonedStaticAllocas |= (APC.DoPoison && !IsDynamicAlloca); } - if (ClInstrumentAllocas && DynamicAllocaVec.size() > 0) { - // Handle dynamic allocas. - createDynamicAllocasInitStorage(); - for (auto &AI : DynamicAllocaVec) handleDynamicAllocaCall(AI); + // Handle dynamic allocas. + createDynamicAllocasInitStorage(); + for (auto &AI : DynamicAllocaVec) + handleDynamicAllocaCall(AI); + unpoisonDynamicAllocas(); +} - unpoisonDynamicAllocas(); +void FunctionStackPoisoner::processStaticAllocas() { + if (AllocaVec.empty()) { + assert(StaticAllocaPoisonCallVec.empty()); + return; } - if (AllocaVec.empty()) return; - int StackMallocIdx = -1; DebugLoc EntryDebugLocation; if (auto SP = F.getSubprogram()) @@ -2060,10 +2448,9 @@ void FunctionStackPoisoner::poisonStack() { // regular stack slots. auto InsBeforeB = InsBefore->getParent(); assert(InsBeforeB == &F.getEntryBlock()); - for (BasicBlock::iterator I(InsBefore); I != InsBeforeB->end(); ++I) - if (auto *AI = dyn_cast<AllocaInst>(I)) - if (NonInstrumentedStaticAllocaVec.count(AI) > 0) - AI->moveBefore(InsBefore); + for (auto *AI : StaticAllocasToMoveUp) + if (AI->getParent() == InsBeforeB) + AI->moveBefore(InsBefore); // If we have a call to llvm.localescape, keep it in the entry block. if (LocalEscapeCall) LocalEscapeCall->moveBefore(InsBefore); @@ -2072,16 +2459,46 @@ void FunctionStackPoisoner::poisonStack() { SVD.reserve(AllocaVec.size()); for (AllocaInst *AI : AllocaVec) { ASanStackVariableDescription D = {AI->getName().data(), - ASan.getAllocaSizeInBytes(AI), - AI->getAlignment(), AI, 0}; + ASan.getAllocaSizeInBytes(*AI), + 0, + AI->getAlignment(), + AI, + 0, + 0}; SVD.push_back(D); } + // Minimal header size (left redzone) is 4 pointers, // i.e. 32 bytes on 64-bit platforms and 16 bytes in 32-bit platforms. 
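As the comment above notes, the minimum frame header (the left redzone) is four pointers, which the code expresses as ASan.LongSize / 2 because LongSize is the pointer width in bits while the header is measured in bytes. A two-line sanity check of that arithmetic:

#include <cstdio>

int main() {
  for (int LongSize : {32, 64}) {
    int MinHeaderSize = LongSize / 2;              // bytes
    printf("%d-bit: min header %d bytes (= 4 * %d-byte pointers)\n",
           LongSize, MinHeaderSize, LongSize / 8);
  }
  return 0;
}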
size_t MinHeaderSize = ASan.LongSize / 2; - ASanStackFrameLayout L; - ComputeASanStackFrameLayout(SVD, 1ULL << Mapping.Scale, MinHeaderSize, &L); - DEBUG(dbgs() << L.DescriptionString << " --- " << L.FrameSize << "\n"); + const ASanStackFrameLayout &L = + ComputeASanStackFrameLayout(SVD, 1ULL << Mapping.Scale, MinHeaderSize); + + // Build AllocaToSVDMap for ASanStackVariableDescription lookup. + DenseMap<const AllocaInst *, ASanStackVariableDescription *> AllocaToSVDMap; + for (auto &Desc : SVD) + AllocaToSVDMap[Desc.AI] = &Desc; + + // Update SVD with information from lifetime intrinsics. + for (const auto &APC : StaticAllocaPoisonCallVec) { + assert(APC.InsBefore); + assert(APC.AI); + assert(ASan.isInterestingAlloca(*APC.AI)); + assert(APC.AI->isStaticAlloca()); + + ASanStackVariableDescription &Desc = *AllocaToSVDMap[APC.AI]; + Desc.LifetimeSize = Desc.Size; + if (const DILocation *FnLoc = EntryDebugLocation.get()) { + if (const DILocation *LifetimeLoc = APC.InsBefore->getDebugLoc().get()) { + if (LifetimeLoc->getFile() == FnLoc->getFile()) + if (unsigned Line = LifetimeLoc->getLine()) + Desc.Line = std::min(Desc.Line ? Desc.Line : Line, Line); + } + } + } + + auto DescriptionString = ComputeASanStackFrameDescription(SVD); + DEBUG(dbgs() << DescriptionString << " --- " << L.FrameSize << "\n"); uint64_t LocalStackSize = L.FrameSize; bool DoStackMalloc = ClUseAfterReturn && !ASan.CompileKernel && LocalStackSize <= kMaxStackMallocSize; @@ -2164,7 +2581,7 @@ void FunctionStackPoisoner::poisonStack() { ConstantInt::get(IntptrTy, ASan.LongSize / 8)), IntptrPtrTy); GlobalVariable *StackDescriptionGlobal = - createPrivateGlobalForString(*F.getParent(), L.DescriptionString, + createPrivateGlobalForString(*F.getParent(), DescriptionString, /*AllowMerging*/ true); Value *Description = IRB.CreatePointerCast(StackDescriptionGlobal, IntptrTy); IRB.CreateStore(Description, BasePlus1); @@ -2175,19 +2592,33 @@ void FunctionStackPoisoner::poisonStack() { IntptrPtrTy); IRB.CreateStore(IRB.CreatePointerCast(&F, IntptrTy), BasePlus2); - // Poison the stack redzones at the entry. - Value *ShadowBase = ASan.memToShadow(LocalStackBase, IRB); - poisonRedZones(L.ShadowBytes, IRB, ShadowBase, true); + const auto &ShadowAfterScope = GetShadowBytesAfterScope(SVD, L); - auto UnpoisonStack = [&](IRBuilder<> &IRB) { - if (HavePoisonedStaticAllocas) { - // If we poisoned some allocas in llvm.lifetime analysis, - // unpoison whole stack frame now. - poisonAlloca(LocalStackBase, LocalStackSize, IRB, false); - } else { - poisonRedZones(L.ShadowBytes, IRB, ShadowBase, false); + // Poison the stack red zones at the entry. + Value *ShadowBase = ASan.memToShadow(LocalStackBase, IRB); + // As mask we must use most poisoned case: red zones and after scope. + // As bytes we can use either the same or just red zones only. + copyToShadow(ShadowAfterScope, ShadowAfterScope, IRB, ShadowBase); + + if (!StaticAllocaPoisonCallVec.empty()) { + const auto &ShadowInScope = GetShadowBytes(SVD, L); + + // Poison static allocas near lifetime intrinsics. + for (const auto &APC : StaticAllocaPoisonCallVec) { + const ASanStackVariableDescription &Desc = *AllocaToSVDMap[APC.AI]; + assert(Desc.Offset % L.Granularity == 0); + size_t Begin = Desc.Offset / L.Granularity; + size_t End = Begin + (APC.Size + L.Granularity - 1) / L.Granularity; + + IRBuilder<> IRB(APC.InsBefore); + copyToShadow(ShadowAfterScope, + APC.DoPoison ? 
ShadowAfterScope : ShadowInScope, Begin, End, + IRB, ShadowBase); } - }; + } + + SmallVector<uint8_t, 64> ShadowClean(ShadowAfterScope.size(), 0); + SmallVector<uint8_t, 64> ShadowAfterReturn; // (Un)poison the stack before all ret instructions. for (auto Ret : RetVec) { @@ -2215,8 +2646,10 @@ void FunctionStackPoisoner::poisonStack() { IRBuilder<> IRBPoison(ThenTerm); if (StackMallocIdx <= 4) { int ClassSize = kMinStackMallocSize << StackMallocIdx; - SetShadowToStackAfterReturnInlined(IRBPoison, ShadowBase, - ClassSize >> Mapping.Scale); + ShadowAfterReturn.resize(ClassSize / L.Granularity, + kAsanStackUseAfterReturnMagic); + copyToShadow(ShadowAfterReturn, ShadowAfterReturn, IRBPoison, + ShadowBase); Value *SavedFlagPtrPtr = IRBPoison.CreateAdd( FakeStack, ConstantInt::get(IntptrTy, ClassSize - ASan.LongSize / 8)); @@ -2233,9 +2666,9 @@ void FunctionStackPoisoner::poisonStack() { } IRBuilder<> IRBElse(ElseTerm); - UnpoisonStack(IRBElse); + copyToShadow(ShadowAfterScope, ShadowClean, IRBElse, ShadowBase); } else { - UnpoisonStack(IRBRet); + copyToShadow(ShadowAfterScope, ShadowClean, IRBRet, ShadowBase); } } @@ -2264,7 +2697,7 @@ void FunctionStackPoisoner::poisonAlloca(Value *V, uint64_t Size, AllocaInst *FunctionStackPoisoner::findAllocaForValue(Value *V) { if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) - // We're intested only in allocas we can handle. + // We're interested only in allocas we can handle. return ASan.isInterestingAlloca(*AI) ? AI : nullptr; // See if we've already calculated (or started to calculate) alloca for a // given value. @@ -2286,6 +2719,10 @@ AllocaInst *FunctionStackPoisoner::findAllocaForValue(Value *V) { return nullptr; Res = IncValueAI; } + } else if (GetElementPtrInst *EP = dyn_cast<GetElementPtrInst>(V)) { + Res = findAllocaForValue(EP->getPointerOperand()); + } else { + DEBUG(dbgs() << "Alloca search canceled on unknown instruction: " << *V << "\n"); } if (Res) AllocaForValue[V] = Res; return Res; diff --git a/contrib/llvm/lib/Transforms/Instrumentation/CFGMST.h b/contrib/llvm/lib/Transforms/Instrumentation/CFGMST.h index 3cd7351..3802f9f 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/CFGMST.h +++ b/contrib/llvm/lib/Transforms/Instrumentation/CFGMST.h @@ -78,6 +78,14 @@ public: return *It->second.get(); } + // Give BB, return the auxiliary information if it's available. + BBInfo *findBBInfo(const BasicBlock *BB) const { + auto It = BBInfos.find(BB); + if (It == BBInfos.end()) + return nullptr; + return It->second.get(); + } + // Traverse the CFG using a stack. Find all the edges and assign the weight. // Edges with large weight will be put into MST first so they are less likely // to be instrumented. diff --git a/contrib/llvm/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp b/contrib/llvm/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp index fb80f87..05eba6c 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp @@ -99,12 +99,23 @@ static const char *const EsanWhichToolName = "__esan_which_tool"; // FIXME: Try to place these shadow constants, the names of the __esan_* // interface functions, and the ToolType enum into a header shared between // llvm and compiler-rt. 
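The EfficiencySanitizer hunks that follow replace the fixed shadow constants with per-target ShadowMemoryParams (a 47-bit layout by default, a 40-bit one for MIPS64), which appToShadow further down applies as Shadow = ((App & Mask) + Offs) >> Scale. A standalone sketch of that mapping using the 47-bit parameters from the hunk (the example address and scale are illustrative):

#include <cstdint>
#include <cstdio>

struct ShadowMemoryParams {
  uint64_t ShadowMask;
  uint64_t ShadowOffs[3];
};

// The 47-bit parameters introduced just below; MIPS64 gets a 40-bit variant.
static const ShadowMemoryParams ShadowParams47 = {
    0x00000fffffffffffull,
    {0x0000130000000000ull, 0x0000220000000000ull, 0x0000440000000000ull}};

// Models appToShadow(): Shadow = ((App & Mask) + Offs) >> Scale, where Scale
// depends on the tool and offsets beyond index 2 are derived by shifting.
uint64_t appToShadow(uint64_t App, const ShadowMemoryParams &P, int Scale) {
  uint64_t Offs = Scale <= 2 ? P.ShadowOffs[Scale] : P.ShadowOffs[0] << Scale;
  uint64_t Shadow = (App & P.ShadowMask) + Offs;
  return Scale > 0 ? Shadow >> Scale : Shadow;
}

int main() {
  printf("0x%llx\n",
         (unsigned long long)appToShadow(0x7fff00001234ull, ShadowParams47, 2));
  return 0;
}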
-static const uint64_t ShadowMask = 0x00000fffffffffffull; -static const uint64_t ShadowOffs[3] = { // Indexed by scale - 0x0000130000000000ull, - 0x0000220000000000ull, - 0x0000440000000000ull, +struct ShadowMemoryParams { + uint64_t ShadowMask; + uint64_t ShadowOffs[3]; }; + +static const ShadowMemoryParams ShadowParams47 = { + 0x00000fffffffffffull, + { + 0x0000130000000000ull, 0x0000220000000000ull, 0x0000440000000000ull, + }}; + +static const ShadowMemoryParams ShadowParams40 = { + 0x0fffffffffull, + { + 0x1300000000ull, 0x2200000000ull, 0x4400000000ull, + }}; + // This array is indexed by the ToolType enum. static const int ShadowScale[] = { 0, // ESAN_None. @@ -154,7 +165,7 @@ public: EfficiencySanitizer( const EfficiencySanitizerOptions &Opts = EfficiencySanitizerOptions()) : ModulePass(ID), Options(OverrideOptionsFromCL(Opts)) {} - const char *getPassName() const override; + StringRef getPassName() const override; void getAnalysisUsage(AnalysisUsage &AU) const override; bool runOnModule(Module &M) override; static char ID; @@ -219,6 +230,7 @@ private: // Remember the counter variable for each struct type to avoid // recomputing the variable name later during instrumentation. std::map<Type *, GlobalVariable *> StructTyMap; + ShadowMemoryParams ShadowParams; }; } // namespace @@ -231,7 +243,7 @@ INITIALIZE_PASS_END( EfficiencySanitizer, "esan", "EfficiencySanitizer: finds performance issues.", false, false) -const char *EfficiencySanitizer::getPassName() const { +StringRef EfficiencySanitizer::getPassName() const { return "EfficiencySanitizer"; } @@ -301,21 +313,21 @@ void EfficiencySanitizer::createStructCounterName( else NameStr += "struct.anon"; // We allow the actual size of the StructCounterName to be larger than - // MaxStructCounterNameSize and append #NumFields and at least one + // MaxStructCounterNameSize and append $NumFields and at least one // field type id. - // Append #NumFields. - NameStr += "#"; + // Append $NumFields. + NameStr += "$"; Twine(StructTy->getNumElements()).toVector(NameStr); // Append struct field type ids in the reverse order. for (int i = StructTy->getNumElements() - 1; i >= 0; --i) { - NameStr += "#"; + NameStr += "$"; Twine(StructTy->getElementType(i)->getTypeID()).toVector(NameStr); if (NameStr.size() >= MaxStructCounterNameSize) break; } if (StructTy->isLiteral()) { - // End with # for literal struct. - NameStr += "#"; + // End with $ for literal struct. 
+ NameStr += "$"; } } @@ -528,6 +540,13 @@ void EfficiencySanitizer::createDestructor(Module &M, Constant *ToolInfoArg) { } bool EfficiencySanitizer::initOnModule(Module &M) { + + Triple TargetTriple(M.getTargetTriple()); + if (TargetTriple.getArch() == Triple::mips64 || TargetTriple.getArch() == Triple::mips64el) + ShadowParams = ShadowParams40; + else + ShadowParams = ShadowParams47; + Ctx = &M.getContext(); const DataLayout &DL = M.getDataLayout(); IRBuilder<> IRB(M.getContext()); @@ -559,13 +578,13 @@ bool EfficiencySanitizer::initOnModule(Module &M) { Value *EfficiencySanitizer::appToShadow(Value *Shadow, IRBuilder<> &IRB) { // Shadow = ((App & Mask) + Offs) >> Scale - Shadow = IRB.CreateAnd(Shadow, ConstantInt::get(IntptrTy, ShadowMask)); + Shadow = IRB.CreateAnd(Shadow, ConstantInt::get(IntptrTy, ShadowParams.ShadowMask)); uint64_t Offs; int Scale = ShadowScale[Options.ToolType]; if (Scale <= 2) - Offs = ShadowOffs[Scale]; + Offs = ShadowParams.ShadowOffs[Scale]; else - Offs = ShadowOffs[0] << Scale; + Offs = ShadowParams.ShadowOffs[0] << Scale; Shadow = IRB.CreateAdd(Shadow, ConstantInt::get(IntptrTy, Offs)); if (Scale > 0) Shadow = IRB.CreateLShr(Shadow, Scale); diff --git a/contrib/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/contrib/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp index b4070b6..56d0f5e 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -118,7 +118,8 @@ private: Function *insertFlush(ArrayRef<std::pair<GlobalVariable *, MDNode *>>); void insertIndirectCounterIncrement(); - std::string mangleName(const DICompileUnit *CU, const char *NewStem); + enum class GCovFileType { GCNO, GCDA }; + std::string mangleName(const DICompileUnit *CU, GCovFileType FileType); GCOVOptions Options; @@ -141,7 +142,7 @@ public: : ModulePass(ID), Profiler(Opts) { initializeGCOVProfilerLegacyPassPass(*PassRegistry::getPassRegistry()); } - const char *getPassName() const override { return "GCOV Profiler"; } + StringRef getPassName() const override { return "GCOV Profiler"; } bool runOnModule(Module &M) override { return Profiler.runOnModule(M); } @@ -251,11 +252,7 @@ namespace { class GCOVBlock : public GCOVRecord { public: GCOVLines &getFile(StringRef Filename) { - GCOVLines *&Lines = LinesByFile[Filename]; - if (!Lines) { - Lines = new GCOVLines(Filename, os); - } - return *Lines; + return LinesByFile.try_emplace(Filename, Filename, os).first->second; } void addEdge(GCOVBlock &Successor) { @@ -264,9 +261,9 @@ namespace { void writeOut() { uint32_t Len = 3; - SmallVector<StringMapEntry<GCOVLines *> *, 32> SortedLinesByFile; + SmallVector<StringMapEntry<GCOVLines> *, 32> SortedLinesByFile; for (auto &I : LinesByFile) { - Len += I.second->length(); + Len += I.second.length(); SortedLinesByFile.push_back(&I); } @@ -274,21 +271,17 @@ namespace { write(Len); write(Number); - std::sort(SortedLinesByFile.begin(), SortedLinesByFile.end(), - [](StringMapEntry<GCOVLines *> *LHS, - StringMapEntry<GCOVLines *> *RHS) { - return LHS->getKey() < RHS->getKey(); - }); + std::sort( + SortedLinesByFile.begin(), SortedLinesByFile.end(), + [](StringMapEntry<GCOVLines> *LHS, StringMapEntry<GCOVLines> *RHS) { + return LHS->getKey() < RHS->getKey(); + }); for (auto &I : SortedLinesByFile) - I->getValue()->writeOut(); + I->getValue().writeOut(); write(0); write(0); } - ~GCOVBlock() { - DeleteContainerSeconds(LinesByFile); - } - GCOVBlock(const GCOVBlock &RHS) : GCOVRecord(RHS), Number(RHS.Number) { // 
Only allow copy before edges and lines have been added. After that, // there are inter-block pointers (eg: edges) that won't take kindly to @@ -306,7 +299,7 @@ namespace { } uint32_t Number; - StringMap<GCOVLines *> LinesByFile; + StringMap<GCOVLines> LinesByFile; SmallVector<GCOVBlock *, 4> OutEdges; }; @@ -426,24 +419,40 @@ namespace { } std::string GCOVProfiler::mangleName(const DICompileUnit *CU, - const char *NewStem) { + GCovFileType OutputType) { + bool Notes = OutputType == GCovFileType::GCNO; + if (NamedMDNode *GCov = M->getNamedMetadata("llvm.gcov")) { for (int i = 0, e = GCov->getNumOperands(); i != e; ++i) { MDNode *N = GCov->getOperand(i); - if (N->getNumOperands() != 2) continue; - MDString *GCovFile = dyn_cast<MDString>(N->getOperand(0)); - MDNode *CompileUnit = dyn_cast<MDNode>(N->getOperand(1)); - if (!GCovFile || !CompileUnit) continue; - if (CompileUnit == CU) { - SmallString<128> Filename = GCovFile->getString(); - sys::path::replace_extension(Filename, NewStem); - return Filename.str(); + bool ThreeElement = N->getNumOperands() == 3; + if (!ThreeElement && N->getNumOperands() != 2) + continue; + if (dyn_cast<MDNode>(N->getOperand(ThreeElement ? 2 : 1)) != CU) + continue; + + if (ThreeElement) { + // These nodes have no mangling to apply, it's stored mangled in the + // bitcode. + MDString *NotesFile = dyn_cast<MDString>(N->getOperand(0)); + MDString *DataFile = dyn_cast<MDString>(N->getOperand(1)); + if (!NotesFile || !DataFile) + continue; + return Notes ? NotesFile->getString() : DataFile->getString(); } + + MDString *GCovFile = dyn_cast<MDString>(N->getOperand(0)); + if (!GCovFile) + continue; + + SmallString<128> Filename = GCovFile->getString(); + sys::path::replace_extension(Filename, Notes ? "gcno" : "gcda"); + return Filename.str(); } } SmallString<128> Filename = CU->getFilename(); - sys::path::replace_extension(Filename, NewStem); + sys::path::replace_extension(Filename, Notes ? "gcno" : "gcda"); StringRef FName = sys::path::filename(Filename); SmallString<128> CurPath; if (sys::fs::current_path(CurPath)) return FName; @@ -461,7 +470,7 @@ bool GCOVProfiler::runOnModule(Module &M) { } PreservedAnalyses GCOVProfilerPass::run(Module &M, - AnalysisManager<Module> &AM) { + ModuleAnalysisManager &AM) { GCOVProfiler Profiler(GCOVOpts); @@ -509,7 +518,7 @@ void GCOVProfiler::emitProfileNotes() { continue; std::error_code EC; - raw_fd_ostream out(mangleName(CU, "gcno"), EC, sys::fs::F_None); + raw_fd_ostream out(mangleName(CU, GCovFileType::GCNO), EC, sys::fs::F_None); std::string EdgeDestinations; unsigned FunctionIdent = 0; @@ -857,7 +866,7 @@ Function *GCOVProfiler::insertCounterWriteout( if (CU->getDWOId()) continue; - std::string FilenameGcda = mangleName(CU, "gcda"); + std::string FilenameGcda = mangleName(CU, GCovFileType::GCDA); uint32_t CfgChecksum = FileChecksums.empty() ? 
0 : FileChecksums[i]; Builder.CreateCall(StartFile, {Builder.CreateGlobalStringPtr(FilenameGcda), diff --git a/contrib/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/contrib/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 202b94b..1ba13bd 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -13,29 +13,38 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/Triple.h" -#include "llvm/Analysis/CFG.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/IndirectCallPromotionAnalysis.h" #include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" #include "llvm/Pass.h" -#include "llvm/ProfileData/InstrProfReader.h" +#include "llvm/PassRegistry.h" +#include "llvm/PassSupport.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include <string> -#include <utility> +#include <cassert> +#include <cstdint> #include <vector> using namespace llvm; @@ -102,9 +111,7 @@ public: *PassRegistry::getPassRegistry()); } - const char *getPassName() const override { - return "PGOIndirectCallPromotion"; - } + StringRef getPassName() const override { return "PGOIndirectCallPromotion"; } private: bool runOnModule(Module &M) override; @@ -208,6 +215,7 @@ public: ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab) : F(Func), M(Modu), Symtab(Symtab) { } + bool processFunction(); }; } // end anonymous namespace @@ -474,7 +482,7 @@ static Instruction *createDirectCallInst(const Instruction *Inst, NewInst); // Clear the value profile data. - NewInst->setMetadata(LLVMContext::MD_prof, 0); + NewInst->setMetadata(LLVMContext::MD_prof, nullptr); CallSite NewCS(NewInst); FunctionType *DirectCalleeType = DirectCallee->getFunctionType(); unsigned ParamNum = DirectCalleeType->getFunctionNumParams(); @@ -610,7 +618,7 @@ bool ICallPromotionFunc::processFunction() { Changed = true; // Adjust the MD.prof metadata. First delete the old one. - I->setMetadata(LLVMContext::MD_prof, 0); + I->setMetadata(LLVMContext::MD_prof, nullptr); // If all promoted, we don't need the MD.prof metadata. 
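Stepping back to the GCOVProfiler::mangleName change above: a !llvm.gcov metadata node may now carry either two operands (a single stem whose extension gets rewritten) or three operands (pre-mangled notes and data file names plus the compile unit). A minimal standalone sketch of that selection in plain C++, not the LLVM metadata API; the node layout mirrors the diff, while the string handling and example names are purely illustrative.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

enum class GCovFileType { GCNO, GCDA };

// Stand-in for one !llvm.gcov operand list: either {stem, CU} or
// {notes-file, data-file, CU}. Plain strings only, no Metadata.
struct GCovNode {
  std::vector<std::string> Ops;
};

static std::string replaceExtension(std::string Path, const std::string &Ext) {
  std::size_t Dot = Path.find_last_of('.');
  if (Dot != std::string::npos)
    Path.erase(Dot);
  return Path + "." + Ext;
}

// Pick the output file name for the notes (.gcno) or data (.gcda) stream.
static std::string mangleName(const GCovNode &N, GCovFileType Type) {
  bool Notes = Type == GCovFileType::GCNO;
  if (N.Ops.size() == 3)            // pre-mangled: use the stored name as-is
    return Notes ? N.Ops[0] : N.Ops[1];
  // Two-operand form: rewrite the extension of the single stem.
  return replaceExtension(N.Ops[0], Notes ? "gcno" : "gcda");
}

int main() {
  GCovNode TwoOp{{"foo.gcda"}};                     // old-style node
  GCovNode ThreeOp{{"bar.gcno", "bar.gcda", "cu"}}; // new-style node
  std::cout << mangleName(TwoOp, GCovFileType::GCNO) << "\n";   // foo.gcno
  std::cout << mangleName(ThreeOp, GCovFileType::GCDA) << "\n"; // bar.gcda
}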
if (TotalCount == 0 || NumPromoted == NumVals) continue; @@ -653,7 +661,7 @@ bool PGOIndirectCallPromotionLegacyPass::runOnModule(Module &M) { return promoteIndirectCalls(M, InLTO | ICPLTOMode); } -PreservedAnalyses PGOIndirectCallPromotion::run(Module &M, AnalysisManager<Module> &AM) { +PreservedAnalyses PGOIndirectCallPromotion::run(Module &M, ModuleAnalysisManager &AM) { if (!promoteIndirectCalls(M, InLTO | ICPLTOMode)) return PreservedAnalyses::all(); diff --git a/contrib/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/contrib/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp index b11c6be..adea7e7 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -15,6 +15,7 @@ #include "llvm/Transforms/InstrProfiling.h" #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" @@ -31,6 +32,11 @@ cl::opt<bool> DoNameCompression("enable-name-compression", cl::desc("Enable name string compression"), cl::init(true)); +cl::opt<bool> DoHashBasedCounterSplit( + "hash-based-counter-split", + cl::desc("Rename counter variable of a comdat function based on cfg hash"), + cl::init(true)); + cl::opt<bool> ValueProfileStaticAlloc( "vp-static-alloc", cl::desc("Do static counter allocation for value profiler"), @@ -53,30 +59,38 @@ public: InstrProfilingLegacyPass() : ModulePass(ID), InstrProf() {} InstrProfilingLegacyPass(const InstrProfOptions &Options) : ModulePass(ID), InstrProf(Options) {} - const char *getPassName() const override { + StringRef getPassName() const override { return "Frontend instrumentation-based coverage lowering"; } - bool runOnModule(Module &M) override { return InstrProf.run(M); } + bool runOnModule(Module &M) override { + return InstrProf.run(M, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI()); + } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); } }; } // anonymous namespace -PreservedAnalyses InstrProfiling::run(Module &M, AnalysisManager<Module> &AM) { - if (!run(M)) +PreservedAnalyses InstrProfiling::run(Module &M, ModuleAnalysisManager &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); + if (!run(M, TLI)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } char InstrProfilingLegacyPass::ID = 0; -INITIALIZE_PASS(InstrProfilingLegacyPass, "instrprof", - "Frontend instrumentation-based coverage lowering.", false, - false) +INITIALIZE_PASS_BEGIN( + InstrProfilingLegacyPass, "instrprof", + "Frontend instrumentation-based coverage lowering.", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END( + InstrProfilingLegacyPass, "instrprof", + "Frontend instrumentation-based coverage lowering.", false, false) ModulePass * llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options) { @@ -107,10 +121,18 @@ StringRef InstrProfiling::getCoverageSection() const { return getInstrProfCoverageSectionName(isMachO()); } -bool InstrProfiling::run(Module &M) { +static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) { + InstrProfIncrementInst *Inc = dyn_cast<InstrProfIncrementInstStep>(Instr); + if (Inc) + return Inc; + return dyn_cast<InstrProfIncrementInst>(Instr); +} + +bool InstrProfiling::run(Module &M, const TargetLibraryInfo &TLI) { bool MadeChange = false; this->M = &M; + this->TLI = &TLI; NamesVar = 
nullptr; NamesSize = 0; ProfileDataMap.clear(); @@ -138,7 +160,8 @@ bool InstrProfiling::run(Module &M) { for (BasicBlock &BB : F) for (auto I = BB.begin(), E = BB.end(); I != E;) { auto Instr = I++; - if (auto *Inc = dyn_cast<InstrProfIncrementInst>(Instr)) { + InstrProfIncrementInst *Inc = castToIncrementInst(&*Instr); + if (Inc) { lowerIncrement(Inc); MadeChange = true; } else if (auto *Ind = dyn_cast<InstrProfValueProfileInst>(Instr)) { @@ -165,7 +188,8 @@ bool InstrProfiling::run(Module &M) { return true; } -static Constant *getOrInsertValueProfilingCall(Module &M) { +static Constant *getOrInsertValueProfilingCall(Module &M, + const TargetLibraryInfo &TLI) { LLVMContext &Ctx = M.getContext(); auto *ReturnTy = Type::getVoidTy(M.getContext()); Type *ParamTypes[] = { @@ -174,8 +198,13 @@ static Constant *getOrInsertValueProfilingCall(Module &M) { }; auto *ValueProfilingCallTy = FunctionType::get(ReturnTy, makeArrayRef(ParamTypes), false); - return M.getOrInsertFunction(getInstrProfValueProfFuncName(), - ValueProfilingCallTy); + Constant *Res = M.getOrInsertFunction(getInstrProfValueProfFuncName(), + ValueProfilingCallTy); + if (Function *FunRes = dyn_cast<Function>(Res)) { + if (auto AK = TLI.getExtAttrForI32Param(false)) + FunRes->addAttribute(3, AK); + } + return Res; } void InstrProfiling::computeNumValueSiteCounts(InstrProfValueProfileInst *Ind) { @@ -209,8 +238,11 @@ void InstrProfiling::lowerValueProfileInst(InstrProfValueProfileInst *Ind) { Value *Args[3] = {Ind->getTargetValue(), Builder.CreateBitCast(DataVar, Builder.getInt8PtrTy()), Builder.getInt32(Index)}; - Ind->replaceAllUsesWith( - Builder.CreateCall(getOrInsertValueProfilingCall(*M), Args)); + CallInst *Call = Builder.CreateCall(getOrInsertValueProfilingCall(*M, *TLI), + Args); + if (auto AK = TLI->getExtAttrForI32Param(false)) + Call->addAttribute(3, AK); + Ind->replaceAllUsesWith(Call); Ind->eraseFromParent(); } @@ -221,7 +253,7 @@ void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) { uint64_t Index = Inc->getIndex()->getZExtValue(); Value *Addr = Builder.CreateConstInBoundsGEP2_64(Counters, 0, Index); Value *Count = Builder.CreateLoad(Addr, "pgocount"); - Count = Builder.CreateAdd(Count, Builder.getInt64(1)); + Count = Builder.CreateAdd(Count, Inc->getStep()); Inc->replaceAllUsesWith(Builder.CreateStore(Count, Addr)); Inc->eraseFromParent(); } @@ -245,7 +277,16 @@ void InstrProfiling::lowerCoverageData(GlobalVariable *CoverageNamesVar) { static std::string getVarName(InstrProfIncrementInst *Inc, StringRef Prefix) { StringRef NamePrefix = getInstrProfNameVarPrefix(); StringRef Name = Inc->getName()->getName().substr(NamePrefix.size()); - return (Prefix + Name).str(); + Function *F = Inc->getParent()->getParent(); + Module *M = F->getParent(); + if (!DoHashBasedCounterSplit || !isIRPGOFlagSet(M) || + !canRenameComdatFunc(*F)) + return (Prefix + Name).str(); + uint64_t FuncHash = Inc->getHash()->getZExtValue(); + SmallVector<char, 24> HashPostfix; + if (Name.endswith((Twine(".") + Twine(FuncHash)).toStringRef(HashPostfix))) + return (Prefix + Name).str(); + return (Prefix + Name + "." 
+ Twine(FuncHash)).str(); } static inline bool shouldRecordFunctionAddr(Function *F) { @@ -268,33 +309,6 @@ static inline bool shouldRecordFunctionAddr(Function *F) { return F->hasAddressTaken() || F->hasLinkOnceLinkage(); } -static inline bool needsComdatForCounter(Function &F, Module &M) { - - if (F.hasComdat()) - return true; - - Triple TT(M.getTargetTriple()); - if (!TT.isOSBinFormatELF()) - return false; - - // See createPGOFuncNameVar for more details. To avoid link errors, profile - // counters for function with available_externally linkage needs to be changed - // to linkonce linkage. On ELF based systems, this leads to weak symbols to be - // created. Without using comdat, duplicate entries won't be removed by the - // linker leading to increased data segement size and raw profile size. Even - // worse, since the referenced counter from profile per-function data object - // will be resolved to the common strong definition, the profile counts for - // available_externally functions will end up being duplicated in raw profile - // data. This can result in distorted profile as the counts of those dups - // will be accumulated by the profile merger. - GlobalValue::LinkageTypes Linkage = F.getLinkage(); - if (Linkage != GlobalValue::ExternalWeakLinkage && - Linkage != GlobalValue::AvailableExternallyLinkage) - return false; - - return true; -} - static inline Comdat *getOrCreateProfileComdat(Module &M, Function &F, InstrProfIncrementInst *Inc) { if (!needsComdatForCounter(F, M)) @@ -572,38 +586,30 @@ void InstrProfiling::emitRuntimeHook() { } void InstrProfiling::emitUses() { - if (UsedVars.empty()) - return; - - GlobalVariable *LLVMUsed = M->getGlobalVariable("llvm.used"); - std::vector<Constant *> MergedVars; - if (LLVMUsed) { - // Collect the existing members of llvm.used. - ConstantArray *Inits = cast<ConstantArray>(LLVMUsed->getInitializer()); - for (unsigned I = 0, E = Inits->getNumOperands(); I != E; ++I) - MergedVars.push_back(Inits->getOperand(I)); - LLVMUsed->eraseFromParent(); - } - - Type *i8PTy = Type::getInt8PtrTy(M->getContext()); - // Add uses for our data. - for (auto *Value : UsedVars) - MergedVars.push_back( - ConstantExpr::getBitCast(cast<Constant>(Value), i8PTy)); - - // Recreate llvm.used. - ArrayType *ATy = ArrayType::get(i8PTy, MergedVars.size()); - LLVMUsed = - new GlobalVariable(*M, ATy, false, GlobalValue::AppendingLinkage, - ConstantArray::get(ATy, MergedVars), "llvm.used"); - LLVMUsed->setSection("llvm.metadata"); + if (!UsedVars.empty()) + appendToUsed(*M, UsedVars); } void InstrProfiling::emitInitialization() { - std::string InstrProfileOutput = Options.InstrProfileOutput; + StringRef InstrProfileOutput = Options.InstrProfileOutput; + + if (!InstrProfileOutput.empty()) { + // Create variable for profile name. + Constant *ProfileNameConst = + ConstantDataArray::getString(M->getContext(), InstrProfileOutput, true); + GlobalVariable *ProfileNameVar = new GlobalVariable( + *M, ProfileNameConst->getType(), true, GlobalValue::WeakAnyLinkage, + ProfileNameConst, INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_NAME_VAR)); + Triple TT(M->getTargetTriple()); + if (TT.supportsCOMDAT()) { + ProfileNameVar->setLinkage(GlobalValue::ExternalLinkage); + ProfileNameVar->setComdat(M->getOrInsertComdat( + StringRef(INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_NAME_VAR)))); + } + } Constant *RegisterF = M->getFunction(getInstrProfRegFuncsName()); - if (!RegisterF && InstrProfileOutput.empty()) + if (!RegisterF) return; // Create the initialization function. 
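The getVarName change above appends the CFG hash to the counter/data variable names of renamable comdat functions, so that a pre-instrumentation inliner cannot make two differently hashed copies collide on one counter variable. A standalone sketch of that naming rule; the three booleans are stand-ins for the real DoHashBasedCounterSplit, isIRPGOFlagSet and canRenameComdatFunc checks.

#include <cstdint>
#include <iostream>
#include <string>

// Decide the final name of a profile variable for a function.
static std::string profileVarName(const std::string &Prefix,
                                  const std::string &FuncName,
                                  uint64_t FuncHash, bool HashSplitEnabled,
                                  bool IsIRPGO, bool RenamableComdat) {
  if (!HashSplitEnabled || !IsIRPGO || !RenamableComdat)
    return Prefix + FuncName;
  std::string Postfix = "." + std::to_string(FuncHash);
  // Don't append twice if the name already ends with this hash.
  if (FuncName.size() >= Postfix.size() &&
      FuncName.compare(FuncName.size() - Postfix.size(), Postfix.size(),
                       Postfix) == 0)
    return Prefix + FuncName;
  return Prefix + FuncName + Postfix;
}

int main() {
  std::cout << profileVarName("__profc_", "foo", 4660, true, true, true)
            << "\n";  // __profc_foo.4660
  std::cout << profileVarName("__profc_", "foo", 4660, true, true, false)
            << "\n";  // __profc_foo
}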
@@ -620,21 +626,6 @@ void InstrProfiling::emitInitialization() { IRBuilder<> IRB(BasicBlock::Create(M->getContext(), "", F)); if (RegisterF) IRB.CreateCall(RegisterF, {}); - if (!InstrProfileOutput.empty()) { - auto *Int8PtrTy = Type::getInt8PtrTy(M->getContext()); - auto *SetNameTy = FunctionType::get(VoidTy, Int8PtrTy, false); - auto *SetNameF = Function::Create(SetNameTy, GlobalValue::ExternalLinkage, - getInstrProfFileOverriderFuncName(), M); - - // Create variable for profile name. - Constant *ProfileNameConst = - ConstantDataArray::getString(M->getContext(), InstrProfileOutput, true); - GlobalVariable *ProfileName = - new GlobalVariable(*M, ProfileNameConst->getType(), true, - GlobalValue::PrivateLinkage, ProfileNameConst); - - IRB.CreateCall(SetNameF, IRB.CreatePointerCast(ProfileName, Int8PtrTy)); - } IRB.CreateRetVoid(); appendToGlobalCtors(*M, F, 0); diff --git a/contrib/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/contrib/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 970f9ab..fafb0fc 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -242,8 +242,8 @@ static const MemoryMapParams Linux_X86_64_MemoryMapParams = { // mips64 Linux static const MemoryMapParams Linux_MIPS64_MemoryMapParams = { - 0x004000000000, // AndMask - 0, // XorMask (not used) + 0, // AndMask (not used) + 0x008000000000, // XorMask 0, // ShadowBase (not used) 0x002000000000, // OriginBase }; @@ -312,11 +312,12 @@ static const PlatformMemoryMapParams FreeBSD_X86_MemoryMapParams = { /// uninitialized reads. class MemorySanitizer : public FunctionPass { public: - MemorySanitizer(int TrackOrigins = 0) + MemorySanitizer(int TrackOrigins = 0, bool Recover = false) : FunctionPass(ID), TrackOrigins(std::max(TrackOrigins, (int)ClTrackOrigins)), + Recover(Recover || ClKeepGoing), WarningFn(nullptr) {} - const char *getPassName() const override { return "MemorySanitizer"; } + StringRef getPassName() const override { return "MemorySanitizer"; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); } @@ -329,6 +330,7 @@ class MemorySanitizer : public FunctionPass { /// \brief Track origins (allocation points) of uninitialized values. int TrackOrigins; + bool Recover; LLVMContext *C; Type *IntptrTy; @@ -395,8 +397,8 @@ INITIALIZE_PASS_END( MemorySanitizer, "msan", "MemorySanitizer: detects uninitialized reads.", false, false) -FunctionPass *llvm::createMemorySanitizerPass(int TrackOrigins) { - return new MemorySanitizer(TrackOrigins); +FunctionPass *llvm::createMemorySanitizerPass(int TrackOrigins, bool Recover) { + return new MemorySanitizer(TrackOrigins, Recover); } /// \brief Create a non-const global initialized with the given string. @@ -421,8 +423,8 @@ void MemorySanitizer::initializeCallbacks(Module &M) { // Create the callback. // FIXME: this function should have "Cold" calling conv, // which is not yet implemented. - StringRef WarningFnName = ClKeepGoing ? "__msan_warning" - : "__msan_warning_noreturn"; + StringRef WarningFnName = Recover ? 
"__msan_warning" + : "__msan_warning_noreturn"; WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy(), nullptr); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; @@ -566,9 +568,9 @@ bool MemorySanitizer::doInitialization(Module &M) { new GlobalVariable(M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage, IRB.getInt32(TrackOrigins), "__msan_track_origins"); - if (ClKeepGoing) + if (Recover) new GlobalVariable(M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage, - IRB.getInt32(ClKeepGoing), "__msan_keep_going"); + IRB.getInt32(Recover), "__msan_keep_going"); return true; } @@ -792,7 +794,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } IRB.CreateCall(MS.WarningFn, {}); IRB.CreateCall(MS.EmptyAsm, {}); - // FIXME: Insert UnreachableInst if !ClKeepGoing? + // FIXME: Insert UnreachableInst if !MS.Recover? // This may invalidate some of the following checks and needs to be done // at the very end. } @@ -815,7 +817,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { getCleanShadow(ConvertedShadow), "_mscmp"); Instruction *CheckTerm = SplitBlockAndInsertIfThen( Cmp, OrigIns, - /* Unreachable */ !ClKeepGoing, MS.ColdCallWeights); + /* Unreachable */ !MS.Recover, MS.ColdCallWeights); IRB.SetInsertPoint(CheckTerm); if (MS.TrackOrigins) { @@ -2360,6 +2362,29 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { case llvm::Intrinsic::x86_sse_cvttps2pi: handleVectorConvertIntrinsic(I, 2); break; + + case llvm::Intrinsic::x86_avx512_psll_w_512: + case llvm::Intrinsic::x86_avx512_psll_d_512: + case llvm::Intrinsic::x86_avx512_psll_q_512: + case llvm::Intrinsic::x86_avx512_pslli_w_512: + case llvm::Intrinsic::x86_avx512_pslli_d_512: + case llvm::Intrinsic::x86_avx512_pslli_q_512: + case llvm::Intrinsic::x86_avx512_psrl_w_512: + case llvm::Intrinsic::x86_avx512_psrl_d_512: + case llvm::Intrinsic::x86_avx512_psrl_q_512: + case llvm::Intrinsic::x86_avx512_psra_w_512: + case llvm::Intrinsic::x86_avx512_psra_d_512: + case llvm::Intrinsic::x86_avx512_psra_q_512: + case llvm::Intrinsic::x86_avx512_psrli_w_512: + case llvm::Intrinsic::x86_avx512_psrli_d_512: + case llvm::Intrinsic::x86_avx512_psrli_q_512: + case llvm::Intrinsic::x86_avx512_psrai_w_512: + case llvm::Intrinsic::x86_avx512_psrai_d_512: + case llvm::Intrinsic::x86_avx512_psrai_q_512: + case llvm::Intrinsic::x86_avx512_psra_q_256: + case llvm::Intrinsic::x86_avx512_psra_q_128: + case llvm::Intrinsic::x86_avx512_psrai_q_256: + case llvm::Intrinsic::x86_avx512_psrai_q_128: case llvm::Intrinsic::x86_avx2_psll_w: case llvm::Intrinsic::x86_avx2_psll_d: case llvm::Intrinsic::x86_avx2_psll_q: @@ -2412,14 +2437,22 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { break; case llvm::Intrinsic::x86_avx2_psllv_d: case llvm::Intrinsic::x86_avx2_psllv_d_256: + case llvm::Intrinsic::x86_avx512_psllv_d_512: case llvm::Intrinsic::x86_avx2_psllv_q: case llvm::Intrinsic::x86_avx2_psllv_q_256: + case llvm::Intrinsic::x86_avx512_psllv_q_512: case llvm::Intrinsic::x86_avx2_psrlv_d: case llvm::Intrinsic::x86_avx2_psrlv_d_256: + case llvm::Intrinsic::x86_avx512_psrlv_d_512: case llvm::Intrinsic::x86_avx2_psrlv_q: case llvm::Intrinsic::x86_avx2_psrlv_q_256: + case llvm::Intrinsic::x86_avx512_psrlv_q_512: case llvm::Intrinsic::x86_avx2_psrav_d: case llvm::Intrinsic::x86_avx2_psrav_d_256: + case llvm::Intrinsic::x86_avx512_psrav_d_512: + case llvm::Intrinsic::x86_avx512_psrav_q_128: + case 
llvm::Intrinsic::x86_avx512_psrav_q_256: + case llvm::Intrinsic::x86_avx512_psrav_q_512: handleVectorShiftIntrinsic(I, /* Variable */ true); break; diff --git a/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index f54d8ad..04f9a64 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -51,6 +51,7 @@ #include "llvm/Transforms/PGOInstrumentation.h" #include "CFGMST.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/BlockFrequencyInfo.h" @@ -59,6 +60,7 @@ #include "llvm/Analysis/IndirectCallSiteVisitor.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" @@ -75,6 +77,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <algorithm> #include <string> +#include <unordered_map> #include <utility> #include <vector> @@ -83,6 +86,7 @@ using namespace llvm; #define DEBUG_TYPE "pgo-instrumentation" STATISTIC(NumOfPGOInstrument, "Number of edges instrumented."); +STATISTIC(NumOfPGOSelectInsts, "Number of select instruction instrumented."); STATISTIC(NumOfPGOEdge, "Number of edges."); STATISTIC(NumOfPGOBB, "Number of basic-blocks."); STATISTIC(NumOfPGOSplit, "Number of critical edge splits."); @@ -112,17 +116,89 @@ static cl::opt<unsigned> MaxNumAnnotations( cl::desc("Max number of annotations for a single indirect " "call callsite")); +// Command line option to control appending FunctionHash to the name of a COMDAT +// function. This is to avoid the hash mismatch caused by the preinliner. +static cl::opt<bool> DoComdatRenaming( + "do-comdat-renaming", cl::init(false), cl::Hidden, + cl::desc("Append function hash to the name of COMDAT function to avoid " + "function hash mismatch due to the preinliner")); + // Command line option to enable/disable the warning about missing profile // information. -static cl::opt<bool> NoPGOWarnMissing("no-pgo-warn-missing", cl::init(false), - cl::Hidden); +static cl::opt<bool> PGOWarnMissing("pgo-warn-missing-function", + cl::init(false), + cl::Hidden); // Command line option to enable/disable the warning about a hash mismatch in // the profile data. static cl::opt<bool> NoPGOWarnMismatch("no-pgo-warn-mismatch", cl::init(false), cl::Hidden); +// Command line option to enable/disable the warning about a hash mismatch in +// the profile data for Comdat functions, which often turns out to be false +// positive due to the pre-instrumentation inline. +static cl::opt<bool> NoPGOWarnMismatchComdat("no-pgo-warn-mismatch-comdat", + cl::init(true), cl::Hidden); + +// Command line option to enable/disable select instruction instrumentation. +static cl::opt<bool> PGOInstrSelect("pgo-instr-select", cl::init(true), + cl::Hidden); namespace { + +/// The select instruction visitor plays three roles specified +/// by the mode. In \c VM_counting mode, it simply counts the number of +/// select instructions. In \c VM_instrument mode, it inserts code to count +/// the number times TrueValue of select is taken. In \c VM_annotate mode, +/// it reads the profile data and annotate the select instruction with metadata. 
+enum VisitMode { VM_counting, VM_instrument, VM_annotate }; +class PGOUseFunc; + +/// Instruction Visitor class to visit select instructions. +struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { + Function &F; + unsigned NSIs = 0; // Number of select instructions instrumented. + VisitMode Mode = VM_counting; // Visiting mode. + unsigned *CurCtrIdx = nullptr; // Pointer to current counter index. + unsigned TotalNumCtrs = 0; // Total number of counters + GlobalVariable *FuncNameVar = nullptr; + uint64_t FuncHash = 0; + PGOUseFunc *UseFunc = nullptr; + + SelectInstVisitor(Function &Func) : F(Func) {} + + void countSelects(Function &Func) { + Mode = VM_counting; + visit(Func); + } + // Visit the IR stream and instrument all select instructions. \p + // Ind is a pointer to the counter index variable; \p TotalNC + // is the total number of counters; \p FNV is the pointer to the + // PGO function name var; \p FHash is the function hash. + void instrumentSelects(Function &Func, unsigned *Ind, unsigned TotalNC, + GlobalVariable *FNV, uint64_t FHash) { + Mode = VM_instrument; + CurCtrIdx = Ind; + TotalNumCtrs = TotalNC; + FuncHash = FHash; + FuncNameVar = FNV; + visit(Func); + } + + // Visit the IR stream and annotate all select instructions. + void annotateSelects(Function &Func, PGOUseFunc *UF, unsigned *Ind) { + Mode = VM_annotate; + UseFunc = UF; + CurCtrIdx = Ind; + visit(Func); + } + + void instrumentOneSelectInst(SelectInst &SI); + void annotateOneSelectInst(SelectInst &SI); + // Visit \p SI instruction and perform tasks according to visit mode. + void visitSelectInst(SelectInst &SI); + unsigned getNumOfSelectInsts() const { return NSIs; } +}; + class PGOInstrumentationGenLegacyPass : public ModulePass { public: static char ID; @@ -132,9 +208,7 @@ public: *PassRegistry::getPassRegistry()); } - const char *getPassName() const override { - return "PGOInstrumentationGenPass"; - } + StringRef getPassName() const override { return "PGOInstrumentationGenPass"; } private: bool runOnModule(Module &M) override; @@ -157,9 +231,7 @@ public: *PassRegistry::getPassRegistry()); } - const char *getPassName() const override { - return "PGOInstrumentationUsePass"; - } + StringRef getPassName() const override { return "PGOInstrumentationUsePass"; } private: std::string ProfileFileName; @@ -169,6 +241,7 @@ private: AU.addRequired<BlockFrequencyInfoWrapperPass>(); } }; + } // end anonymous namespace char PGOInstrumentationGenLegacyPass::ID = 0; @@ -238,8 +311,13 @@ template <class Edge, class BBInfo> class FuncPGOInstrumentation { private: Function &F; void computeCFGHash(); + void renameComdatFunction(); + // A map that stores the Comdat group in function F. + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers; public: + std::vector<Instruction *> IndirectCallSites; + SelectInstVisitor SIVisitor; std::string FuncName; GlobalVariable *FuncNameVar; // CFG hash value for this function. @@ -255,18 +333,32 @@ public: // Return the auxiliary BB information. BBInfo &getBBInfo(const BasicBlock *BB) const { return MST.getBBInfo(BB); } + // Return the auxiliary BB information if available. + BBInfo *findBBInfo(const BasicBlock *BB) const { return MST.findBBInfo(BB); } + // Dump edges and BB information. 
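The SelectInstVisitor above walks the same function three times under different modes: count, then instrument, then annotate. A loose standalone model of that mode-driven dispatch (plain C++, not InstVisitor, and with a toy "instruction stream"); in the real pass the counter index is shared with edge instrumentation and passed in by pointer rather than reset per run.

#include <iostream>
#include <vector>

enum VisitMode { VM_counting, VM_instrument, VM_annotate };

struct SelectVisitorModel {
  unsigned NSIs = 0;            // number of selects seen in the last run
  VisitMode Mode = VM_counting;
  unsigned CurCtrIdx = 0;       // next counter slot to use/read (simplified)

  void visitOneSelect(int SelectId) {
    ++NSIs;
    switch (Mode) {
    case VM_counting:
      return;                   // just count, nothing is emitted
    case VM_instrument:
      std::cout << "select " << SelectId << " -> counter " << CurCtrIdx++ << "\n";
      return;
    case VM_annotate:
      std::cout << "select " << SelectId << " reads counter " << CurCtrIdx++ << "\n";
      return;
    }
  }

  void run(const std::vector<int> &Selects, VisitMode M) {
    Mode = M;
    NSIs = 0;
    CurCtrIdx = 0;
    for (int S : Selects)
      visitOneSelect(S);
  }
};

int main() {
  SelectVisitorModel V;
  std::vector<int> Selects = {1, 2, 3};
  V.run(Selects, VM_counting);   // sizing pass: how many counters are needed
  std::cout << "select counters needed: " << V.NSIs << "\n";
  V.run(Selects, VM_instrument); // instrumentation pass: assign counter slots
  V.run(Selects, VM_annotate);   // use pass: read the counters back
}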
void dumpInfo(std::string Str = "") const { MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName + " Hash: " + Twine(FunctionHash) + "\t" + Str); } - FuncPGOInstrumentation(Function &Func, bool CreateGlobalVar = false, - BranchProbabilityInfo *BPI = nullptr, - BlockFrequencyInfo *BFI = nullptr) - : F(Func), FunctionHash(0), MST(F, BPI, BFI) { + FuncPGOInstrumentation( + Function &Func, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, + bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr, + BlockFrequencyInfo *BFI = nullptr) + : F(Func), ComdatMembers(ComdatMembers), SIVisitor(Func), FunctionHash(0), + MST(F, BPI, BFI) { + + // This should be done before CFG hash computation. + SIVisitor.countSelects(Func); + NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); + IndirectCallSites = findIndirectCallSites(Func); + FuncName = getPGOFuncName(F); computeCFGHash(); + if (ComdatMembers.size()) + renameComdatFunction(); DEBUG(dumpInfo("after CFGMST")); NumOfPGOBB += MST.BBInfos.size(); @@ -281,6 +373,16 @@ public: if (CreateGlobalVar) FuncNameVar = createPGOFuncNameVar(F, FuncName); } + + // Return the number of profile counters needed for the function. + unsigned getNumCounters() { + unsigned NumCounters = 0; + for (auto &E : this->MST.AllEdges) { + if (!E->InMST && !E->Removed) + NumCounters++; + } + return NumCounters + SIVisitor.getNumOfSelectInsts(); + } }; // Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index @@ -293,13 +395,90 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { const TerminatorInst *TI = BB.getTerminator(); for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) { BasicBlock *Succ = TI->getSuccessor(I); - uint32_t Index = getBBInfo(Succ).Index; + auto BI = findBBInfo(Succ); + if (BI == nullptr) + continue; + uint32_t Index = BI->Index; for (int J = 0; J < 4; J++) Indexes.push_back((char)(Index >> (J * 8))); } } JC.update(Indexes); - FunctionHash = (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); + FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 | + (uint64_t)IndirectCallSites.size() << 48 | + (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); +} + +// Check if we can safely rename this Comdat function. +static bool canRenameComdat( + Function &F, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers) { + if (!DoComdatRenaming || !canRenameComdatFunc(F, true)) + return false; + + // FIXME: Current only handle those Comdat groups that only containing one + // function and function aliases. + // (1) For a Comdat group containing multiple functions, we need to have a + // unique postfix based on the hashes for each function. There is a + // non-trivial code refactoring to do this efficiently. + // (2) Variables can not be renamed, so we can not rename Comdat function in a + // group including global vars. + Comdat *C = F.getComdat(); + for (auto &&CM : make_range(ComdatMembers.equal_range(C))) { + if (dyn_cast<GlobalAlias>(CM.second)) + continue; + Function *FM = dyn_cast<Function>(CM.second); + if (FM != &F) + return false; + } + return true; +} + +// Append the CFGHash to the Comdat function name. +template <class Edge, class BBInfo> +void FuncPGOInstrumentation<Edge, BBInfo>::renameComdatFunction() { + if (!canRenameComdat(F, ComdatMembers)) + return; + std::string OrigName = F.getName().str(); + std::string NewFuncName = + Twine(F.getName() + "." 
+ Twine(FunctionHash)).str(); + F.setName(Twine(NewFuncName)); + GlobalAlias::create(GlobalValue::WeakAnyLinkage, OrigName, &F); + FuncName = Twine(FuncName + "." + Twine(FunctionHash)).str(); + Comdat *NewComdat; + Module *M = F.getParent(); + // For AvailableExternallyLinkage functions, change the linkage to + // LinkOnceODR and put them into comdat. This is because after renaming, there + // is no backup external copy available for the function. + if (!F.hasComdat()) { + assert(F.getLinkage() == GlobalValue::AvailableExternallyLinkage); + NewComdat = M->getOrInsertComdat(StringRef(NewFuncName)); + F.setLinkage(GlobalValue::LinkOnceODRLinkage); + F.setComdat(NewComdat); + return; + } + + // This function belongs to a single function Comdat group. + Comdat *OrigComdat = F.getComdat(); + std::string NewComdatName = + Twine(OrigComdat->getName() + "." + Twine(FunctionHash)).str(); + NewComdat = M->getOrInsertComdat(StringRef(NewComdatName)); + NewComdat->setSelectionKind(OrigComdat->getSelectionKind()); + + for (auto &&CM : make_range(ComdatMembers.equal_range(OrigComdat))) { + if (GlobalAlias *GA = dyn_cast<GlobalAlias>(CM.second)) { + // For aliases, change the name directly. + assert(dyn_cast<Function>(GA->getAliasee()->stripPointerCasts()) == &F); + std::string OrigGAName = GA->getName().str(); + GA->setName(Twine(GA->getName() + "." + Twine(FunctionHash))); + GlobalAlias::create(GlobalValue::WeakAnyLinkage, OrigGAName, GA); + continue; + } + // Must be a function. + Function *CF = dyn_cast<Function>(CM.second); + assert(CF); + CF->setComdat(NewComdat); + } } // Given a CFG E to be instrumented, find which BB to place the instrumented @@ -340,15 +519,12 @@ BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) { // Visit all edge and instrument the edges not in MST, and do value profiling. // Critical edges will be split. 
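The computeCFGHash change above folds the number of instrumented selects and indirect call sites into the function hash alongside the edge count and the CRC32 of the successor indices. A standalone sketch of just the packing step, with the CRC32 input gathering omitted.

#include <cstdint>
#include <cstdio>

// Pack the per-function counts and the CFG CRC into the 64-bit hash that the
// profile reader later compares against. Field positions follow the diff:
// selects in bits 56..63, indirect calls in 48..55, edges in 32..47, and the
// CRC32 in the low 32 bits. A count that overflows its field simply bleeds
// into the neighbouring one, which is acceptable for a hash.
static uint64_t packFunctionHash(uint64_t NumSelects, uint64_t NumIndirectCalls,
                                 uint64_t NumEdges, uint32_t CfgCrc32) {
  return (NumSelects << 56) | (NumIndirectCalls << 48) | (NumEdges << 32) |
         CfgCrc32;
}

int main() {
  std::printf("%016llx\n",
              (unsigned long long)packFunctionHash(2, 3, 17, 0xdeadbeef));
  // prints 02030011deadbeef
}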
-static void instrumentOneFunc(Function &F, Module *M, - BranchProbabilityInfo *BPI, - BlockFrequencyInfo *BFI) { - unsigned NumCounters = 0; - FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, true, BPI, BFI); - for (auto &E : FuncInfo.MST.AllEdges) { - if (!E->InMST && !E->Removed) - NumCounters++; - } +static void instrumentOneFunc( + Function &F, Module *M, BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFI, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers) { + FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, ComdatMembers, true, BPI, + BFI); + unsigned NumCounters = FuncInfo.getNumCounters(); uint32_t I = 0; Type *I8PtrTy = Type::getInt8PtrTy(M->getContext()); @@ -367,11 +543,16 @@ static void instrumentOneFunc(Function &F, Module *M, Builder.getInt32(I++)}); } + // Now instrument select instructions: + FuncInfo.SIVisitor.instrumentSelects(F, &I, NumCounters, FuncInfo.FuncNameVar, + FuncInfo.FunctionHash); + assert(I == NumCounters); + if (DisableValueProfiling) return; unsigned NumIndirectCallSites = 0; - for (auto &I : findIndirectCallSites(F)) { + for (auto &I : FuncInfo.IndirectCallSites) { CallSite CS(I); Value *Callee = CS.getCalledValue(); DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = " @@ -456,10 +637,12 @@ static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) { class PGOUseFunc { public: - PGOUseFunc(Function &Func, Module *Modu, BranchProbabilityInfo *BPI = nullptr, + PGOUseFunc(Function &Func, Module *Modu, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, + BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr) - : F(Func), M(Modu), FuncInfo(Func, false, BPI, BFI), - FreqAttr(FFA_Normal) {} + : F(Func), M(Modu), FuncInfo(Func, ComdatMembers, false, BPI, BFI), + CountPosition(0), ProfileCountSize(0), FreqAttr(FFA_Normal) {} // Read counts for the instrumented BB from profile. bool readCounters(IndexedInstrProfReader *PGOReader); @@ -479,24 +662,37 @@ public: // Return the function hotness from the profile. FuncFreqAttr getFuncFreqAttr() const { return FreqAttr; } + // Return the function hash. + uint64_t getFuncHash() const { return FuncInfo.FunctionHash; } // Return the profile record for this function; InstrProfRecord &getProfileRecord() { return ProfileRecord; } + // Return the auxiliary BB information. + UseBBInfo &getBBInfo(const BasicBlock *BB) const { + return FuncInfo.getBBInfo(BB); + } + + // Return the auxiliary BB information if available. + UseBBInfo *findBBInfo(const BasicBlock *BB) const { + return FuncInfo.findBBInfo(BB); + } + private: Function &F; Module *M; // This member stores the shared information with class PGOGenFunc. FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo; - // Return the auxiliary BB information. - UseBBInfo &getBBInfo(const BasicBlock *BB) const { - return FuncInfo.getBBInfo(BB); - } - // The maximum count value in the profile. This is only used in PGO use // compilation. uint64_t ProgramMaxCount; + // Position of counter that remains to be read. + uint32_t CountPosition; + + // Total size of the profile count for this function. + uint32_t ProfileCountSize; + // ProfileRecord for this function. InstrProfRecord ProfileRecord; @@ -535,6 +731,7 @@ private: void PGOUseFunc::setInstrumentedCounts( const std::vector<uint64_t> &CountFromProfile) { + assert(FuncInfo.getNumCounters() == CountFromProfile.size()); // Use a worklist as we will update the vector during the iteration. 
std::vector<PGOUseEdge *> WorkList; for (auto &E : FuncInfo.MST.AllEdges) @@ -564,6 +761,8 @@ void PGOUseFunc::setInstrumentedCounts( NewEdge1.InMST = true; getBBInfo(InstrBB).setBBInfoCount(CountValue); } + ProfileCountSize = CountFromProfile.size(); + CountPosition = I; } // Set the count value for the unknown edge. There should be one and only one @@ -594,11 +793,15 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) { bool SkipWarning = false; if (Err == instrprof_error::unknown_function) { NumOfPGOMissing++; - SkipWarning = NoPGOWarnMissing; + SkipWarning = !PGOWarnMissing; } else if (Err == instrprof_error::hash_mismatch || Err == instrprof_error::malformed) { NumOfPGOMismatch++; - SkipWarning = NoPGOWarnMismatch; + SkipWarning = + NoPGOWarnMismatch || + (NoPGOWarnMismatchComdat && + (F.hasComdat() || + F.getLinkage() == GlobalValue::AvailableExternallyLinkage)); } if (SkipWarning) @@ -663,27 +866,38 @@ void PGOUseFunc::populateCounters() { // For efficient traversal, it's better to start from the end as most // of the instrumented edges are at the end. for (auto &BB : reverse(F)) { - UseBBInfo &Count = getBBInfo(&BB); - if (!Count.CountValid) { - if (Count.UnknownCountOutEdge == 0) { - Count.CountValue = sumEdgeCount(Count.OutEdges); - Count.CountValid = true; + UseBBInfo *Count = findBBInfo(&BB); + if (Count == nullptr) + continue; + if (!Count->CountValid) { + if (Count->UnknownCountOutEdge == 0) { + Count->CountValue = sumEdgeCount(Count->OutEdges); + Count->CountValid = true; Changes = true; - } else if (Count.UnknownCountInEdge == 0) { - Count.CountValue = sumEdgeCount(Count.InEdges); - Count.CountValid = true; + } else if (Count->UnknownCountInEdge == 0) { + Count->CountValue = sumEdgeCount(Count->InEdges); + Count->CountValid = true; Changes = true; } } - if (Count.CountValid) { - if (Count.UnknownCountOutEdge == 1) { - uint64_t Total = Count.CountValue - sumEdgeCount(Count.OutEdges); - setEdgeCount(Count.OutEdges, Total); + if (Count->CountValid) { + if (Count->UnknownCountOutEdge == 1) { + uint64_t Total = 0; + uint64_t OutSum = sumEdgeCount(Count->OutEdges); + // If the one of the successor block can early terminate (no-return), + // we can end up with situation where out edge sum count is larger as + // the source BB's count is collected by a post-dominated block. + if (Count->CountValue > OutSum) + Total = Count->CountValue - OutSum; + setEdgeCount(Count->OutEdges, Total); Changes = true; } - if (Count.UnknownCountInEdge == 1) { - uint64_t Total = Count.CountValue - sumEdgeCount(Count.InEdges); - setEdgeCount(Count.InEdges, Total); + if (Count->UnknownCountInEdge == 1) { + uint64_t Total = 0; + uint64_t InSum = sumEdgeCount(Count->InEdges); + if (Count->CountValue > InSum) + Total = Count->CountValue - InSum; + setEdgeCount(Count->InEdges, Total); Changes = true; } } @@ -693,24 +907,50 @@ void PGOUseFunc::populateCounters() { DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n"); #ifndef NDEBUG // Assert every BB has a valid counter. 
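The populateCounters change above guards the edge-count subtraction: when a successor never returns, counts collected in post-dominated blocks can make the out-edge (or in-edge) sum exceed the block's own count, so the single unknown edge is assigned zero rather than letting unsigned subtraction wrap. A tiny standalone illustration of that clamp.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Given a block count and the counts of all but one outgoing edge, return the
// count to assign to the remaining (uninstrumented) edge. Clamp at zero
// instead of underflowing when a no-return successor inflates the edge sum.
static uint64_t remainingEdgeCount(uint64_t BlockCount,
                                   const std::vector<uint64_t> &KnownEdges) {
  uint64_t Sum =
      std::accumulate(KnownEdges.begin(), KnownEdges.end(), uint64_t{0});
  return BlockCount > Sum ? BlockCount - Sum : 0;
}

int main() {
  // Normal case: 100 executions, the known edges account for 60 of them.
  std::cout << remainingEdgeCount(100, {40, 20}) << "\n";  // 40
  // No-return case: edge counters add up to more than the block count.
  std::cout << remainingEdgeCount(100, {90, 30}) << "\n";  // 0, not a wrapped value
}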
- for (auto &BB : F) - assert(getBBInfo(&BB).CountValid && "BB count is not valid"); + for (auto &BB : F) { + auto BI = findBBInfo(&BB); + if (BI == nullptr) + continue; + assert(BI->CountValid && "BB count is not valid"); + } #endif uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue; F.setEntryCount(FuncEntryCount); uint64_t FuncMaxCount = FuncEntryCount; - for (auto &BB : F) - FuncMaxCount = std::max(FuncMaxCount, getBBInfo(&BB).CountValue); + for (auto &BB : F) { + auto BI = findBBInfo(&BB); + if (BI == nullptr) + continue; + FuncMaxCount = std::max(FuncMaxCount, BI->CountValue); + } markFunctionAttributes(FuncEntryCount, FuncMaxCount); + // Now annotate select instructions + FuncInfo.SIVisitor.annotateSelects(F, this, &CountPosition); + assert(CountPosition == ProfileCountSize); + DEBUG(FuncInfo.dumpInfo("after reading profile.")); } +static void setProfMetadata(Module *M, Instruction *TI, + ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) { + MDBuilder MDB(M->getContext()); + assert(MaxCount > 0 && "Bad max count"); + uint64_t Scale = calculateCountScale(MaxCount); + SmallVector<unsigned, 4> Weights; + for (const auto &ECI : EdgeCounts) + Weights.push_back(scaleBranchCount(ECI, Scale)); + + DEBUG(dbgs() << "Weight is: "; + for (const auto &W : Weights) { dbgs() << W << " "; } + dbgs() << "\n";); + TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); +} + // Assign the scaled count values to the BB with multiple out edges. void PGOUseFunc::setBranchWeights() { // Generate MD_prof metadata for every branch instruction. DEBUG(dbgs() << "\nSetting branch weights.\n"); - MDBuilder MDB(M->getContext()); for (auto &BB : F) { TerminatorInst *TI = BB.getTerminator(); if (TI->getNumSuccessors() < 2) @@ -723,7 +963,7 @@ void PGOUseFunc::setBranchWeights() { // We have a non-zero Branch BB. 
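The setProfMetadata helper above funnels 64-bit edge counts through calculateCountScale/scaleBranchCount so they fit the 32-bit branch_weights metadata. A standalone sketch under the assumption that the scale is picked so the largest count divides down to fit in a uint32_t; the real helpers live in InstrProf and may differ in detail.

#include <cstdint>
#include <iostream>
#include <vector>

// Assumed behaviour of calculateCountScale: 1 while everything fits, else a
// divisor large enough to bring MaxCount under UINT32_MAX.
static uint64_t countScale(uint64_t MaxCount) {
  return MaxCount < UINT32_MAX ? 1 : MaxCount / UINT32_MAX + 1;
}

// Scale one edge count down to a 32-bit branch weight.
static uint32_t scaleCount(uint64_t Count, uint64_t Scale) {
  return static_cast<uint32_t>(Count / Scale);
}

int main() {
  std::vector<uint64_t> EdgeCounts = {6000000000ULL, 2000000000ULL};
  uint64_t Scale = countScale(EdgeCounts[0]);      // MaxCount is the first edge
  for (uint64_t C : EdgeCounts)
    std::cout << scaleCount(C, Scale) << " ";      // 3000000000 1000000000
  std::cout << "\n";
}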
const UseBBInfo &BBCountInfo = getBBInfo(&BB); unsigned Size = BBCountInfo.OutEdges.size(); - SmallVector<unsigned, 2> EdgeCounts(Size, 0); + SmallVector<uint64_t, 2> EdgeCounts(Size, 0); uint64_t MaxCount = 0; for (unsigned s = 0; s < Size; s++) { const PGOUseEdge *E = BBCountInfo.OutEdges[s]; @@ -737,20 +977,64 @@ void PGOUseFunc::setBranchWeights() { MaxCount = EdgeCount; EdgeCounts[SuccNum] = EdgeCount; } - assert(MaxCount > 0 && "Bad max count"); - uint64_t Scale = calculateCountScale(MaxCount); - SmallVector<unsigned, 4> Weights; - for (const auto &ECI : EdgeCounts) - Weights.push_back(scaleBranchCount(ECI, Scale)); - - TI->setMetadata(llvm::LLVMContext::MD_prof, - MDB.createBranchWeights(Weights)); - DEBUG(dbgs() << "Weight is: "; - for (const auto &W : Weights) { dbgs() << W << " "; } - dbgs() << "\n";); + setProfMetadata(M, TI, EdgeCounts, MaxCount); } } +void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) { + Module *M = F.getParent(); + IRBuilder<> Builder(&SI); + Type *Int64Ty = Builder.getInt64Ty(); + Type *I8PtrTy = Builder.getInt8PtrTy(); + auto *Step = Builder.CreateZExt(SI.getCondition(), Int64Ty); + Builder.CreateCall( + Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step), + {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), + Builder.getInt64(FuncHash), + Builder.getInt32(TotalNumCtrs), Builder.getInt32(*CurCtrIdx), Step}); + ++(*CurCtrIdx); +} + +void SelectInstVisitor::annotateOneSelectInst(SelectInst &SI) { + std::vector<uint64_t> &CountFromProfile = UseFunc->getProfileRecord().Counts; + assert(*CurCtrIdx < CountFromProfile.size() && + "Out of bound access of counters"); + uint64_t SCounts[2]; + SCounts[0] = CountFromProfile[*CurCtrIdx]; // True count + ++(*CurCtrIdx); + uint64_t TotalCount = 0; + auto BI = UseFunc->findBBInfo(SI.getParent()); + if (BI != nullptr) + TotalCount = BI->CountValue; + // False Count + SCounts[1] = (TotalCount > SCounts[0] ? TotalCount - SCounts[0] : 0); + uint64_t MaxCount = std::max(SCounts[0], SCounts[1]); + if (MaxCount) + setProfMetadata(F.getParent(), &SI, SCounts, MaxCount); +} + +void SelectInstVisitor::visitSelectInst(SelectInst &SI) { + if (!PGOInstrSelect) + return; + // FIXME: do not handle this yet. + if (SI.getCondition()->getType()->isVectorTy()) + return; + + NSIs++; + switch (Mode) { + case VM_counting: + return; + case VM_instrument: + instrumentOneSelectInst(SI); + return; + case VM_annotate: + annotateOneSelectInst(SI); + return; + } + + llvm_unreachable("Unknown visiting mode"); +} + // Traverse all the indirect callsites and annotate the instructions. void PGOUseFunc::annotateIndirectCallSites() { if (DisableValueProfiling) @@ -760,7 +1044,7 @@ void PGOUseFunc::annotateIndirectCallSites() { createPGOFuncNameMetadata(F, FuncInfo.FuncName); unsigned IndirectCallSiteIndex = 0; - auto IndirectCallSites = findIndirectCallSites(F); + auto &IndirectCallSites = FuncInfo.IndirectCallSites; unsigned NumValueSites = ProfileRecord.getNumValueSites(IPVK_IndirectCallTarget); if (NumValueSites != IndirectCallSites.size()) { @@ -784,7 +1068,7 @@ void PGOUseFunc::annotateIndirectCallSites() { } } // end anonymous namespace -// Create a COMDAT variable IR_LEVEL_PROF_VARNAME to make the runtime +// Create a COMDAT variable INSTR_PROF_RAW_VERSION_VAR to make the runtime // aware this is an ir_level profile so it can set the version flag. 
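For selects, the instrumentation above bumps the counter by zext(i1 cond), so the recorded value is how often the true arm was taken; annotation then derives the false count from the enclosing block's total. A standalone sketch of that arithmetic, with the same clamp to zero the pass uses when the block count is missing or stale.

#include <cstdint>
#include <iostream>
#include <utility>

// Derive {true, false} execution counts for a select from its profile counter
// (true-arm count) and the count of the block containing it.
static std::pair<uint64_t, uint64_t> selectCounts(uint64_t TrueArmCounter,
                                                  uint64_t BlockCount) {
  uint64_t FalseCount =
      BlockCount > TrueArmCounter ? BlockCount - TrueArmCounter : 0;
  return {TrueArmCounter, FalseCount};
}

int main() {
  auto [T, F] = selectCounts(70, 100);
  std::cout << "true=" << T << " false=" << F << "\n";   // true=70 false=30
  auto [T2, F2] = selectCounts(70, 0);                    // block count unknown
  std::cout << "true=" << T2 << " false=" << F2 << "\n";  // false clamped to 0
}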
static void createIRLevelProfileFlagVariable(Module &M) { Type *IntTy64 = Type::getInt64Ty(M.getContext()); @@ -792,26 +1076,47 @@ static void createIRLevelProfileFlagVariable(Module &M) { auto IRLevelVersionVariable = new GlobalVariable( M, IntTy64, true, GlobalVariable::ExternalLinkage, Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)), - INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR)); + INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR)); IRLevelVersionVariable->setVisibility(GlobalValue::DefaultVisibility); Triple TT(M.getTargetTriple()); if (!TT.supportsCOMDAT()) IRLevelVersionVariable->setLinkage(GlobalValue::WeakAnyLinkage); else IRLevelVersionVariable->setComdat(M.getOrInsertComdat( - StringRef(INSTR_PROF_QUOTE(IR_LEVEL_PROF_VERSION_VAR)))); + StringRef(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR)))); +} + +// Collect the set of members for each Comdat in module M and store +// in ComdatMembers. +static void collectComdatMembers( + Module &M, + std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers) { + if (!DoComdatRenaming) + return; + for (Function &F : M) + if (Comdat *C = F.getComdat()) + ComdatMembers.insert(std::make_pair(C, &F)); + for (GlobalVariable &GV : M.globals()) + if (Comdat *C = GV.getComdat()) + ComdatMembers.insert(std::make_pair(C, &GV)); + for (GlobalAlias &GA : M.aliases()) + if (Comdat *C = GA.getComdat()) + ComdatMembers.insert(std::make_pair(C, &GA)); } static bool InstrumentAllFunctions( Module &M, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) { createIRLevelProfileFlagVariable(M); + std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers; + collectComdatMembers(M, ComdatMembers); + for (auto &F : M) { if (F.isDeclaration()) continue; auto *BPI = LookupBPI(F); auto *BFI = LookupBFI(F); - instrumentOneFunc(F, &M, BPI, BFI); + instrumentOneFunc(F, &M, BPI, BFI, ComdatMembers); } return true; } @@ -830,7 +1135,7 @@ bool PGOInstrumentationGenLegacyPass::runOnModule(Module &M) { } PreservedAnalyses PGOInstrumentationGen::run(Module &M, - AnalysisManager<Module> &AM) { + ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto LookupBPI = [&FAM](Function &F) { @@ -877,6 +1182,8 @@ static bool annotateAllFunctions( return false; } + std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers; + collectComdatMembers(M, ComdatMembers); std::vector<Function *> HotFunctions; std::vector<Function *> ColdFunctions; for (auto &F : M) { @@ -884,7 +1191,7 @@ static bool annotateAllFunctions( continue; auto *BPI = LookupBPI(F); auto *BFI = LookupBFI(F); - PGOUseFunc Func(F, &M, BPI, BFI); + PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI); if (!Func.readCounters(PGOReader.get())) continue; Func.populateCounters(); @@ -910,7 +1217,6 @@ static bool annotateAllFunctions( F->addFnAttr(llvm::Attribute::Cold); DEBUG(dbgs() << "Set cold attribute to function: " << F->getName() << "\n"); } - return true; } @@ -921,7 +1227,7 @@ PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename) } PreservedAnalyses PGOInstrumentationUse::run(Module &M, - AnalysisManager<Module> &AM) { + ModuleAnalysisManager &AM) { auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto LookupBPI = [&FAM](Function &F) { diff --git a/contrib/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/contrib/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 7d40447..5b4b1fb 100644 --- 
a/contrib/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -67,11 +67,23 @@ static const char *const SanCovTraceEnterName = static const char *const SanCovTraceBBName = "__sanitizer_cov_trace_basic_block"; static const char *const SanCovTracePCName = "__sanitizer_cov_trace_pc"; -static const char *const SanCovTraceCmpName = "__sanitizer_cov_trace_cmp"; +static const char *const SanCovTraceCmp1 = "__sanitizer_cov_trace_cmp1"; +static const char *const SanCovTraceCmp2 = "__sanitizer_cov_trace_cmp2"; +static const char *const SanCovTraceCmp4 = "__sanitizer_cov_trace_cmp4"; +static const char *const SanCovTraceCmp8 = "__sanitizer_cov_trace_cmp8"; +static const char *const SanCovTraceDiv4 = "__sanitizer_cov_trace_div4"; +static const char *const SanCovTraceDiv8 = "__sanitizer_cov_trace_div8"; +static const char *const SanCovTraceGep = "__sanitizer_cov_trace_gep"; static const char *const SanCovTraceSwitchName = "__sanitizer_cov_trace_switch"; static const char *const SanCovModuleCtorName = "sancov.module_ctor"; static const uint64_t SanCtorAndDtorPriority = 2; +static const char *const SanCovTracePCGuardSection = "__sancov_guards"; +static const char *const SanCovTracePCGuardName = + "__sanitizer_cov_trace_pc_guard"; +static const char *const SanCovTracePCGuardInitName = + "__sanitizer_cov_trace_pc_guard_init"; + static cl::opt<int> ClCoverageLevel( "sanitizer-coverage-level", cl::desc("Sanitizer Coverage. 0: none, 1: entry block, 2: all blocks, " @@ -95,11 +107,22 @@ static cl::opt<bool> ClExperimentalTracePC("sanitizer-coverage-trace-pc", cl::desc("Experimental pc tracing"), cl::Hidden, cl::init(false)); +static cl::opt<bool> ClTracePCGuard("sanitizer-coverage-trace-pc-guard", + cl::desc("pc tracing with a guard"), + cl::Hidden, cl::init(false)); + static cl::opt<bool> - ClExperimentalCMPTracing("sanitizer-coverage-experimental-trace-compares", - cl::desc("Experimental tracing of CMP and similar " - "instructions"), - cl::Hidden, cl::init(false)); + ClCMPTracing("sanitizer-coverage-trace-compares", + cl::desc("Tracing of CMP and similar instructions"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> ClDIVTracing("sanitizer-coverage-trace-divs", + cl::desc("Tracing of DIV instructions"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> ClGEPTracing("sanitizer-coverage-trace-geps", + cl::desc("Tracing of GEP instructions"), + cl::Hidden, cl::init(false)); static cl::opt<bool> ClPruneBlocks("sanitizer-coverage-prune-blocks", @@ -147,9 +170,12 @@ SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) { Options.CoverageType = std::max(Options.CoverageType, CLOpts.CoverageType); Options.IndirectCalls |= CLOpts.IndirectCalls; Options.TraceBB |= ClExperimentalTracing; - Options.TraceCmp |= ClExperimentalCMPTracing; + Options.TraceCmp |= ClCMPTracing; + Options.TraceDiv |= ClDIVTracing; + Options.TraceGep |= ClGEPTracing; Options.Use8bitCounters |= ClUse8bitCounters; Options.TracePC |= ClExperimentalTracePC; + Options.TracePCGuard |= ClTracePCGuard; return Options; } @@ -163,7 +189,7 @@ public: bool runOnModule(Module &M) override; bool runOnFunction(Function &F); static char ID; // Pass identification, replacement for typeid - const char *getPassName() const override { return "SanitizerCoverageModule"; } + StringRef getPassName() const override { return "SanitizerCoverageModule"; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<DominatorTreeWrapperPass>(); @@ 
-174,11 +200,17 @@ private: void InjectCoverageForIndirectCalls(Function &F, ArrayRef<Instruction *> IndirCalls); void InjectTraceForCmp(Function &F, ArrayRef<Instruction *> CmpTraceTargets); + void InjectTraceForDiv(Function &F, + ArrayRef<BinaryOperator *> DivTraceTargets); + void InjectTraceForGep(Function &F, + ArrayRef<GetElementPtrInst *> GepTraceTargets); void InjectTraceForSwitch(Function &F, ArrayRef<Instruction *> SwitchTraceTargets); bool InjectCoverage(Function &F, ArrayRef<BasicBlock *> AllBlocks); + void CreateFunctionGuardArray(size_t NumGuards, Function &F); void SetNoSanitizeMetadata(Instruction *I); - void InjectCoverageAtBlock(Function &F, BasicBlock &BB, bool UseCalls); + void InjectCoverageAtBlock(Function &F, BasicBlock &BB, size_t Idx, + bool UseCalls); unsigned NumberOfInstrumentedBlocks() { return SanCovFunction->getNumUses() + SanCovWithCheckFunction->getNumUses() + SanCovTraceBB->getNumUses() + @@ -187,17 +219,21 @@ private: Function *SanCovFunction; Function *SanCovWithCheckFunction; Function *SanCovIndirCallFunction, *SanCovTracePCIndir; - Function *SanCovTraceEnter, *SanCovTraceBB, *SanCovTracePC; - Function *SanCovTraceCmpFunction; + Function *SanCovTraceEnter, *SanCovTraceBB, *SanCovTracePC, *SanCovTracePCGuard; + Function *SanCovTraceCmpFunction[4]; + Function *SanCovTraceDivFunction[2]; + Function *SanCovTraceGepFunction; Function *SanCovTraceSwitchFunction; InlineAsm *EmptyAsm; - Type *IntptrTy, *Int64Ty, *Int64PtrTy; + Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy; Module *CurModule; LLVMContext *C; const DataLayout *DL; GlobalVariable *GuardArray; + GlobalVariable *FunctionGuardArray; // for trace-pc-guard. GlobalVariable *EightBitCounterArray; + bool HasSancovGuardsSection; SanitizerCoverageOptions Options; }; @@ -210,13 +246,16 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { C = &(M.getContext()); DL = &M.getDataLayout(); CurModule = &M; + HasSancovGuardsSection = false; IntptrTy = Type::getIntNTy(*C, DL->getPointerSizeInBits()); + IntptrPtrTy = PointerType::getUnqual(IntptrTy); Type *VoidTy = Type::getVoidTy(*C); IRBuilder<> IRB(*C); Type *Int8PtrTy = PointerType::getUnqual(IRB.getInt8Ty()); - Type *Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty()); Int64PtrTy = PointerType::getUnqual(IRB.getInt64Ty()); + Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty()); Int64Ty = IRB.getInt64Ty(); + Int32Ty = IRB.getInt32Ty(); SanCovFunction = checkSanitizerInterfaceFunction( M.getOrInsertFunction(SanCovName, VoidTy, Int32PtrTy, nullptr)); @@ -227,9 +266,28 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { SanCovIndirCallFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( SanCovIndirCallName, VoidTy, IntptrTy, IntptrTy, nullptr)); - SanCovTraceCmpFunction = + SanCovTraceCmpFunction[0] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + SanCovTraceCmp1, VoidTy, IRB.getInt8Ty(), IRB.getInt8Ty(), nullptr)); + SanCovTraceCmpFunction[1] = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(SanCovTraceCmp2, VoidTy, IRB.getInt16Ty(), + IRB.getInt16Ty(), nullptr)); + SanCovTraceCmpFunction[2] = checkSanitizerInterfaceFunction( + M.getOrInsertFunction(SanCovTraceCmp4, VoidTy, IRB.getInt32Ty(), + IRB.getInt32Ty(), nullptr)); + SanCovTraceCmpFunction[3] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - SanCovTraceCmpName, VoidTy, Int64Ty, Int64Ty, Int64Ty, nullptr)); + SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty, nullptr)); + + SanCovTraceDivFunction[0] = + 
checkSanitizerInterfaceFunction(M.getOrInsertFunction( + SanCovTraceDiv4, VoidTy, IRB.getInt32Ty(), nullptr)); + SanCovTraceDivFunction[1] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + SanCovTraceDiv8, VoidTy, Int64Ty, nullptr)); + SanCovTraceGepFunction = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + SanCovTraceGep, VoidTy, IntptrTy, nullptr)); SanCovTraceSwitchFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy, nullptr)); @@ -241,6 +299,8 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { SanCovTracePC = checkSanitizerInterfaceFunction( M.getOrInsertFunction(SanCovTracePCName, VoidTy, nullptr)); + SanCovTracePCGuard = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + SanCovTracePCGuardName, VoidTy, Int32PtrTy, nullptr)); SanCovTraceEnter = checkSanitizerInterfaceFunction( M.getOrInsertFunction(SanCovTraceEnterName, VoidTy, Int32PtrTy, nullptr)); SanCovTraceBB = checkSanitizerInterfaceFunction( @@ -251,9 +311,10 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { Type *Int32Ty = IRB.getInt32Ty(); Type *Int8Ty = IRB.getInt8Ty(); - GuardArray = - new GlobalVariable(M, Int32Ty, false, GlobalValue::ExternalLinkage, - nullptr, "__sancov_gen_cov_tmp"); + if (!Options.TracePCGuard) + GuardArray = + new GlobalVariable(M, Int32Ty, false, GlobalValue::ExternalLinkage, + nullptr, "__sancov_gen_cov_tmp"); if (Options.Use8bitCounters) EightBitCounterArray = new GlobalVariable(M, Int8Ty, false, GlobalVariable::ExternalLinkage, @@ -264,17 +325,20 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { auto N = NumberOfInstrumentedBlocks(); - // Now we know how many elements we need. Create an array of guards - // with one extra element at the beginning for the size. - Type *Int32ArrayNTy = ArrayType::get(Int32Ty, N + 1); - GlobalVariable *RealGuardArray = new GlobalVariable( - M, Int32ArrayNTy, false, GlobalValue::PrivateLinkage, - Constant::getNullValue(Int32ArrayNTy), "__sancov_gen_cov"); - - // Replace the dummy array with the real one. - GuardArray->replaceAllUsesWith( - IRB.CreatePointerCast(RealGuardArray, Int32PtrTy)); - GuardArray->eraseFromParent(); + GlobalVariable *RealGuardArray = nullptr; + if (!Options.TracePCGuard) { + // Now we know how many elements we need. Create an array of guards + // with one extra element at the beginning for the size. + Type *Int32ArrayNTy = ArrayType::get(Int32Ty, N + 1); + RealGuardArray = new GlobalVariable( + M, Int32ArrayNTy, false, GlobalValue::PrivateLinkage, + Constant::getNullValue(Int32ArrayNTy), "__sancov_gen_cov"); + + // Replace the dummy array with the real one. 
+ GuardArray->replaceAllUsesWith( + IRB.CreatePointerCast(RealGuardArray, Int32PtrTy)); + GuardArray->eraseFromParent(); + } GlobalVariable *RealEightBitCounterArray; if (Options.Use8bitCounters) { @@ -293,11 +357,30 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { // Create variable for module (compilation unit) name Constant *ModNameStrConst = ConstantDataArray::getString(M.getContext(), M.getName(), true); - GlobalVariable *ModuleName = - new GlobalVariable(M, ModNameStrConst->getType(), true, - GlobalValue::PrivateLinkage, ModNameStrConst); + GlobalVariable *ModuleName = new GlobalVariable( + M, ModNameStrConst->getType(), true, GlobalValue::PrivateLinkage, + ModNameStrConst, "__sancov_gen_modname"); + if (Options.TracePCGuard) { + if (HasSancovGuardsSection) { + Function *CtorFunc; + std::string SectionName(SanCovTracePCGuardSection); + GlobalVariable *Bounds[2]; + const char *Prefix[2] = {"__start_", "__stop_"}; + for (int i = 0; i < 2; i++) { + Bounds[i] = new GlobalVariable(M, Int32PtrTy, false, + GlobalVariable::ExternalLinkage, nullptr, + Prefix[i] + SectionName); + Bounds[i]->setVisibility(GlobalValue::HiddenVisibility); + } + std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( + M, SanCovModuleCtorName, SanCovTracePCGuardInitName, + {Int32PtrTy, Int32PtrTy}, + {IRB.CreatePointerCast(Bounds[0], Int32PtrTy), + IRB.CreatePointerCast(Bounds[1], Int32PtrTy)}); - if (!Options.TracePC) { + appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); + } + } else if (!Options.TracePC) { Function *CtorFunc; std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( M, SanCovModuleCtorName, SanCovModuleInitName, @@ -344,6 +427,14 @@ static bool isFullPostDominator(const BasicBlock *BB, static bool shouldInstrumentBlock(const Function& F, const BasicBlock *BB, const DominatorTree *DT, const PostDominatorTree *PDT) { + // Don't insert coverage for unreachable blocks: we will never call + // __sanitizer_cov() for them, so counting them in + // NumberOfInstrumentedBlocks() might complicate calculation of code coverage + // percentage. Also, unreachable instructions frequently have no debug + // locations. + if (isa<UnreachableInst>(BB->getTerminator())) + return false; + if (!ClPruneBlocks || &F.getEntryBlock() == BB) return true; @@ -355,6 +446,13 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { return false; if (F.getName().find(".module_ctor") != std::string::npos) return false; // Should not instrument sanitizer init functions. + if (F.getName().startswith("__sanitizer_")) + return false; // Don't instrument __sanitizer_* callbacks. + // Don't instrument MSVC CRT configuration helpers. They may run before normal + // initialization. + if (F.getName() == "__local_stdio_printf_options" || + F.getName() == "__local_stdio_scanf_options") + return false; // Don't instrument functions using SEH for now. Splitting basic blocks like // we do for coverage breaks WinEHPrepare. // FIXME: Remove this when SEH no longer uses landingpad pattern matching. 
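For context on the trace-pc-guard mode wired up above: each instrumented block loads its 32-bit slot from the per-function guard array (placed in the section named by SanCovTracePCGuardSection and delimited by the __start_/__stop_ bound symbols created in runOnModule) and, unless the slot is zero, calls into the coverage runtime. The sketch below is a minimal toy runtime, assuming the conventional callback names behind SanCovTracePCGuardName and SanCovTracePCGuardInitName (__sanitizer_cov_trace_pc_guard and __sanitizer_cov_trace_pc_guard_init); it is illustrative only and not part of this patch.

// Illustrative only -- a toy runtime for the callbacks this pass emits calls
// to; real runtimes (e.g. libFuzzer) do considerably more.
#include <cstdint>
#include <cstdio>

// Called from the module constructor with the bounds of the guard section;
// gives every guard slot a unique non-zero id exactly once.
extern "C" void __sanitizer_cov_trace_pc_guard_init(uint32_t *start,
                                                    uint32_t *stop) {
  static uint32_t NumGuards = 0;
  if (start == stop || *start)
    return;
  for (uint32_t *G = start; G < stop; ++G)
    *G = ++NumGuards;
}

// Called from instrumented basic blocks.
extern "C" void __sanitizer_cov_trace_pc_guard(uint32_t *guard) {
  if (!*guard)
    return;
  void *PC = __builtin_return_address(0);
  std::fprintf(stderr, "sancov: guard %u hit at PC %p\n",
               static_cast<unsigned>(*guard), PC);
  *guard = 0; // report each edge only once
}

Zeroing the guard in the callback is what lets the inlined non-zero check emitted by the pass skip the call on later executions of the same edge.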
@@ -367,6 +465,8 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { SmallVector<BasicBlock *, 16> BlocksToInstrument; SmallVector<Instruction *, 8> CmpTraceTargets; SmallVector<Instruction *, 8> SwitchTraceTargets; + SmallVector<BinaryOperator *, 8> DivTraceTargets; + SmallVector<GetElementPtrInst *, 8> GepTraceTargets; const DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); @@ -388,28 +488,53 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { if (isa<SwitchInst>(&Inst)) SwitchTraceTargets.push_back(&Inst); } - } + if (Options.TraceDiv) + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(&Inst)) + if (BO->getOpcode() == Instruction::SDiv || + BO->getOpcode() == Instruction::UDiv) + DivTraceTargets.push_back(BO); + if (Options.TraceGep) + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&Inst)) + GepTraceTargets.push_back(GEP); + } } InjectCoverage(F, BlocksToInstrument); InjectCoverageForIndirectCalls(F, IndirCalls); InjectTraceForCmp(F, CmpTraceTargets); InjectTraceForSwitch(F, SwitchTraceTargets); + InjectTraceForDiv(F, DivTraceTargets); + InjectTraceForGep(F, GepTraceTargets); return true; } +void SanitizerCoverageModule::CreateFunctionGuardArray(size_t NumGuards, + Function &F) { + if (!Options.TracePCGuard) return; + HasSancovGuardsSection = true; + ArrayType *ArrayOfInt32Ty = ArrayType::get(Int32Ty, NumGuards); + FunctionGuardArray = new GlobalVariable( + *CurModule, ArrayOfInt32Ty, false, GlobalVariable::PrivateLinkage, + Constant::getNullValue(ArrayOfInt32Ty), "__sancov_gen_"); + if (auto Comdat = F.getComdat()) + FunctionGuardArray->setComdat(Comdat); + FunctionGuardArray->setSection(SanCovTracePCGuardSection); +} bool SanitizerCoverageModule::InjectCoverage(Function &F, ArrayRef<BasicBlock *> AllBlocks) { + if (AllBlocks.empty()) return false; switch (Options.CoverageType) { case SanitizerCoverageOptions::SCK_None: return false; case SanitizerCoverageOptions::SCK_Function: - InjectCoverageAtBlock(F, F.getEntryBlock(), false); + CreateFunctionGuardArray(1, F); + InjectCoverageAtBlock(F, F.getEntryBlock(), 0, false); return true; default: { bool UseCalls = ClCoverageBlockThreshold < AllBlocks.size(); - for (auto BB : AllBlocks) - InjectCoverageAtBlock(F, *BB, UseCalls); + CreateFunctionGuardArray(AllBlocks.size(), F); + for (size_t i = 0, N = AllBlocks.size(); i < N; i++) + InjectCoverageAtBlock(F, *AllBlocks[i], i, UseCalls); return true; } } @@ -439,7 +564,7 @@ void SanitizerCoverageModule::InjectCoverageForIndirectCalls( *F.getParent(), Ty, false, GlobalValue::PrivateLinkage, Constant::getNullValue(Ty), "__sancov_gen_callee_cache"); CalleeCache->setAlignment(CacheAlignment); - if (Options.TracePC) + if (Options.TracePC || Options.TracePCGuard) IRB.CreateCall(SanCovTracePCIndir, IRB.CreatePointerCast(Callee, IntptrTy)); else @@ -476,6 +601,11 @@ void SanitizerCoverageModule::InjectTraceForSwitch( C = ConstantExpr::getCast(CastInst::ZExt, It.getCaseValue(), Int64Ty); Initializers.push_back(C); } + std::sort(Initializers.begin() + 2, Initializers.end(), + [](const Constant *A, const Constant *B) { + return cast<ConstantInt>(A)->getLimitedValue() < + cast<ConstantInt>(B)->getLimitedValue(); + }); ArrayType *ArrayOfInt64Ty = ArrayType::get(Int64Ty, Initializers.size()); GlobalVariable *GV = new GlobalVariable( *CurModule, ArrayOfInt64Ty, false, GlobalVariable::InternalLinkage, @@ -487,6 +617,35 @@ void SanitizerCoverageModule::InjectTraceForSwitch( } } +void SanitizerCoverageModule::InjectTraceForDiv( + Function &, 
ArrayRef<BinaryOperator *> DivTraceTargets) { + for (auto BO : DivTraceTargets) { + IRBuilder<> IRB(BO); + Value *A1 = BO->getOperand(1); + if (isa<ConstantInt>(A1)) continue; + if (!A1->getType()->isIntegerTy()) + continue; + uint64_t TypeSize = DL->getTypeStoreSizeInBits(A1->getType()); + int CallbackIdx = TypeSize == 32 ? 0 : + TypeSize == 64 ? 1 : -1; + if (CallbackIdx < 0) continue; + auto Ty = Type::getIntNTy(*C, TypeSize); + IRB.CreateCall(SanCovTraceDivFunction[CallbackIdx], + {IRB.CreateIntCast(A1, Ty, true)}); + } +} + +void SanitizerCoverageModule::InjectTraceForGep( + Function &, ArrayRef<GetElementPtrInst *> GepTraceTargets) { + for (auto GEP : GepTraceTargets) { + IRBuilder<> IRB(GEP); + for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) + if (!isa<ConstantInt>(*I) && (*I)->getType()->isIntegerTy()) + IRB.CreateCall(SanCovTraceGepFunction, + {IRB.CreateIntCast(*I, IntptrTy, true)}); + } +} + void SanitizerCoverageModule::InjectTraceForCmp( Function &, ArrayRef<Instruction *> CmpTraceTargets) { for (auto I : CmpTraceTargets) { @@ -497,12 +656,16 @@ void SanitizerCoverageModule::InjectTraceForCmp( if (!A0->getType()->isIntegerTy()) continue; uint64_t TypeSize = DL->getTypeStoreSizeInBits(A0->getType()); + int CallbackIdx = TypeSize == 8 ? 0 : + TypeSize == 16 ? 1 : + TypeSize == 32 ? 2 : + TypeSize == 64 ? 3 : -1; + if (CallbackIdx < 0) continue; // __sanitizer_cov_trace_cmp((type_size << 32) | predicate, A0, A1); + auto Ty = Type::getIntNTy(*C, TypeSize); IRB.CreateCall( - SanCovTraceCmpFunction, - {ConstantInt::get(Int64Ty, (TypeSize << 32) | ICMP->getPredicate()), - IRB.CreateIntCast(A0, Int64Ty, true), - IRB.CreateIntCast(A1, Int64Ty, true)}); + SanCovTraceCmpFunction[CallbackIdx], + {IRB.CreateIntCast(A0, Ty, true), IRB.CreateIntCast(A1, Ty, true)}); } } } @@ -513,16 +676,8 @@ void SanitizerCoverageModule::SetNoSanitizeMetadata(Instruction *I) { } void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, - bool UseCalls) { - // Don't insert coverage for unreachable blocks: we will never call - // __sanitizer_cov() for them, so counting them in - // NumberOfInstrumentedBlocks() might complicate calculation of code coverage - // percentage. Also, unreachable instructions frequently have no debug - // locations. - if (isa<UnreachableInst>(BB.getTerminator())) - return; + size_t Idx, bool UseCalls) { BasicBlock::iterator IP = BB.getFirstInsertionPt(); - bool IsEntryBB = &BB == &F.getEntryBlock(); DebugLoc EntryLoc; if (IsEntryBB) { @@ -538,32 +693,52 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, IRBuilder<> IRB(&*IP); IRB.SetCurrentDebugLocation(EntryLoc); - Value *GuardP = IRB.CreateAdd( - IRB.CreatePointerCast(GuardArray, IntptrTy), - ConstantInt::get(IntptrTy, (1 + NumberOfInstrumentedBlocks()) * 4)); - Type *Int32PtrTy = PointerType::getUnqual(IRB.getInt32Ty()); - GuardP = IRB.CreateIntToPtr(GuardP, Int32PtrTy); if (Options.TracePC) { IRB.CreateCall(SanCovTracePC); // gets the PC using GET_CALLER_PC. IRB.CreateCall(EmptyAsm, {}); // Avoids callback merge. - } else if (Options.TraceBB) { - IRB.CreateCall(IsEntryBB ? 
SanCovTraceEnter : SanCovTraceBB, GuardP); - } else if (UseCalls) { - IRB.CreateCall(SanCovWithCheckFunction, GuardP); - } else { - LoadInst *Load = IRB.CreateLoad(GuardP); - Load->setAtomic(AtomicOrdering::Monotonic); - Load->setAlignment(4); - SetNoSanitizeMetadata(Load); - Value *Cmp = - IRB.CreateICmpSGE(Constant::getNullValue(Load->getType()), Load); - Instruction *Ins = SplitBlockAndInsertIfThen( - Cmp, &*IP, false, MDBuilder(*C).createBranchWeights(1, 100000)); - IRB.SetInsertPoint(Ins); - IRB.SetCurrentDebugLocation(EntryLoc); - // __sanitizer_cov gets the PC of the instruction using GET_CALLER_PC. - IRB.CreateCall(SanCovFunction, GuardP); + } else if (Options.TracePCGuard) { + auto GuardPtr = IRB.CreateIntToPtr( + IRB.CreateAdd(IRB.CreatePointerCast(FunctionGuardArray, IntptrTy), + ConstantInt::get(IntptrTy, Idx * 4)), + Int32PtrTy); + if (!UseCalls) { + auto GuardLoad = IRB.CreateLoad(GuardPtr); + GuardLoad->setAtomic(AtomicOrdering::Monotonic); + GuardLoad->setAlignment(8); + SetNoSanitizeMetadata(GuardLoad); // Don't instrument with e.g. asan. + auto Cmp = IRB.CreateICmpNE( + GuardLoad, Constant::getNullValue(GuardLoad->getType())); + auto Ins = SplitBlockAndInsertIfThen( + Cmp, &*IP, false, MDBuilder(*C).createBranchWeights(1, 100000)); + IRB.SetInsertPoint(Ins); + IRB.SetCurrentDebugLocation(EntryLoc); + } + IRB.CreateCall(SanCovTracePCGuard, GuardPtr); IRB.CreateCall(EmptyAsm, {}); // Avoids callback merge. + } else { + Value *GuardP = IRB.CreateAdd( + IRB.CreatePointerCast(GuardArray, IntptrTy), + ConstantInt::get(IntptrTy, (1 + NumberOfInstrumentedBlocks()) * 4)); + GuardP = IRB.CreateIntToPtr(GuardP, Int32PtrTy); + if (Options.TraceBB) { + IRB.CreateCall(IsEntryBB ? SanCovTraceEnter : SanCovTraceBB, GuardP); + } else if (UseCalls) { + IRB.CreateCall(SanCovWithCheckFunction, GuardP); + } else { + LoadInst *Load = IRB.CreateLoad(GuardP); + Load->setAtomic(AtomicOrdering::Monotonic); + Load->setAlignment(4); + SetNoSanitizeMetadata(Load); + Value *Cmp = + IRB.CreateICmpSGE(Constant::getNullValue(Load->getType()), Load); + Instruction *Ins = SplitBlockAndInsertIfThen( + Cmp, &*IP, false, MDBuilder(*C).createBranchWeights(1, 100000)); + IRB.SetInsertPoint(Ins); + IRB.SetCurrentDebugLocation(EntryLoc); + // __sanitizer_cov gets the PC of the instruction using GET_CALLER_PC. + IRB.CreateCall(SanCovFunction, GuardP); + IRB.CreateCall(EmptyAsm, {}); // Avoids callback merge. 
+ } } if (Options.Use8bitCounters) { diff --git a/contrib/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/contrib/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index 41041c7..52035c7 100644 --- a/contrib/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/contrib/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -43,6 +43,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/EscapeEnumerator.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -56,6 +57,10 @@ static cl::opt<bool> ClInstrumentMemoryAccesses( static cl::opt<bool> ClInstrumentFuncEntryExit( "tsan-instrument-func-entry-exit", cl::init(true), cl::desc("Instrument function entry and exit"), cl::Hidden); +static cl::opt<bool> ClHandleCxxExceptions( + "tsan-handle-cxx-exceptions", cl::init(true), + cl::desc("Handle C++ exceptions (insert cleanup blocks for unwinding)"), + cl::Hidden); static cl::opt<bool> ClInstrumentAtomics( "tsan-instrument-atomics", cl::init(true), cl::desc("Instrument atomics"), cl::Hidden); @@ -83,7 +88,7 @@ namespace { /// ThreadSanitizer: instrument the code in module to find races. struct ThreadSanitizer : public FunctionPass { ThreadSanitizer() : FunctionPass(ID) {} - const char *getPassName() const override; + StringRef getPassName() const override; void getAnalysisUsage(AnalysisUsage &AU) const override; bool runOnFunction(Function &F) override; bool doInitialization(Module &M) override; @@ -99,12 +104,15 @@ struct ThreadSanitizer : public FunctionPass { const DataLayout &DL); bool addrPointsToConstantData(Value *Addr); int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL); + void InsertRuntimeIgnores(Function &F); Type *IntptrTy; IntegerType *OrdTy; // Callbacks to run-time library are computed in doInitialization. Function *TsanFuncEntry; Function *TsanFuncExit; + Function *TsanIgnoreBegin; + Function *TsanIgnoreEnd; // Accesses sizes are powers of two: 1, 2, 4, 8, 16. static const size_t kNumberOfAccessSizes = 5; Function *TsanRead[kNumberOfAccessSizes]; @@ -135,9 +143,7 @@ INITIALIZE_PASS_END( "ThreadSanitizer: detects data races.", false, false) -const char *ThreadSanitizer::getPassName() const { - return "ThreadSanitizer"; -} +StringRef ThreadSanitizer::getPassName() const { return "ThreadSanitizer"; } void ThreadSanitizer::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<TargetLibraryInfoWrapperPass>(); @@ -149,11 +155,17 @@ FunctionPass *llvm::createThreadSanitizerPass() { void ThreadSanitizer::initializeCallbacks(Module &M) { IRBuilder<> IRB(M.getContext()); + AttributeSet Attr; + Attr = Attr.addAttribute(M.getContext(), AttributeSet::FunctionIndex, Attribute::NoUnwind); // Initialize the callbacks. 
TsanFuncEntry = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_func_entry", IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + "__tsan_func_entry", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); TsanFuncExit = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("__tsan_func_exit", IRB.getVoidTy(), nullptr)); + M.getOrInsertFunction("__tsan_func_exit", Attr, IRB.getVoidTy(), nullptr)); + TsanIgnoreBegin = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + "__tsan_ignore_thread_begin", Attr, IRB.getVoidTy(), nullptr)); + TsanIgnoreEnd = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + "__tsan_ignore_thread_end", Attr, IRB.getVoidTy(), nullptr)); OrdTy = IRB.getInt32Ty(); for (size_t i = 0; i < kNumberOfAccessSizes; ++i) { const unsigned ByteSize = 1U << i; @@ -162,31 +174,31 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { std::string BitSizeStr = utostr(BitSize); SmallString<32> ReadName("__tsan_read" + ByteSizeStr); TsanRead[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ReadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + ReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); SmallString<32> WriteName("__tsan_write" + ByteSizeStr); TsanWrite[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - WriteName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + WriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); SmallString<64> UnalignedReadName("__tsan_unaligned_read" + ByteSizeStr); TsanUnalignedRead[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedReadName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + UnalignedReadName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); SmallString<64> UnalignedWriteName("__tsan_unaligned_write" + ByteSizeStr); TsanUnalignedWrite[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - UnalignedWriteName, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + UnalignedWriteName, Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); Type *Ty = Type::getIntNTy(M.getContext(), BitSize); Type *PtrTy = Ty->getPointerTo(); SmallString<32> AtomicLoadName("__tsan_atomic" + BitSizeStr + "_load"); TsanAtomicLoad[i] = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(AtomicLoadName, Ty, PtrTy, OrdTy, nullptr)); + M.getOrInsertFunction(AtomicLoadName, Attr, Ty, PtrTy, OrdTy, nullptr)); SmallString<32> AtomicStoreName("__tsan_atomic" + BitSizeStr + "_store"); TsanAtomicStore[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AtomicStoreName, IRB.getVoidTy(), PtrTy, Ty, OrdTy, nullptr)); + AtomicStoreName, Attr, IRB.getVoidTy(), PtrTy, Ty, OrdTy, nullptr)); for (int op = AtomicRMWInst::FIRST_BINOP; op <= AtomicRMWInst::LAST_BINOP; ++op) { @@ -210,32 +222,32 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { continue; SmallString<32> RMWName("__tsan_atomic" + itostr(BitSize) + NamePart); TsanAtomicRMW[op][i] = checkSanitizerInterfaceFunction( - M.getOrInsertFunction(RMWName, Ty, PtrTy, Ty, OrdTy, nullptr)); + M.getOrInsertFunction(RMWName, Attr, Ty, PtrTy, Ty, OrdTy, nullptr)); } SmallString<32> AtomicCASName("__tsan_atomic" + BitSizeStr + "_compare_exchange_val"); TsanAtomicCAS[i] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - AtomicCASName, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, nullptr)); + AtomicCASName, Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, nullptr)); } TsanVptrUpdate = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("__tsan_vptr_update", IRB.getVoidTy(), + 
M.getOrInsertFunction("__tsan_vptr_update", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), nullptr)); TsanVptrLoad = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_vptr_read", IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); + "__tsan_vptr_read", Attr, IRB.getVoidTy(), IRB.getInt8PtrTy(), nullptr)); TsanAtomicThreadFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_atomic_thread_fence", IRB.getVoidTy(), OrdTy, nullptr)); + "__tsan_atomic_thread_fence", Attr, IRB.getVoidTy(), OrdTy, nullptr)); TsanAtomicSignalFence = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__tsan_atomic_signal_fence", IRB.getVoidTy(), OrdTy, nullptr)); + "__tsan_atomic_signal_fence", Attr, IRB.getVoidTy(), OrdTy, nullptr)); MemmoveFn = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("memmove", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + M.getOrInsertFunction("memmove", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr)); MemcpyFn = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("memcpy", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + M.getOrInsertFunction("memcpy", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy, nullptr)); MemsetFn = checkSanitizerInterfaceFunction( - M.getOrInsertFunction("memset", IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), + M.getOrInsertFunction("memset", Attr, IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy, nullptr)); } @@ -378,13 +390,21 @@ static bool isAtomic(Instruction *I) { return false; } +void ThreadSanitizer::InsertRuntimeIgnores(Function &F) { + IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); + IRB.CreateCall(TsanIgnoreBegin); + EscapeEnumerator EE(F, "tsan_ignore_cleanup", ClHandleCxxExceptions); + while (IRBuilder<> *AtExit = EE.Next()) { + AtExit->CreateCall(TsanIgnoreEnd); + } +} + bool ThreadSanitizer::runOnFunction(Function &F) { // This is required to prevent instrumenting call to __tsan_init from within // the module constructor. if (&F == TsanCtorFunction) return false; initializeCallbacks(*F.getParent()); - SmallVector<Instruction*, 8> RetVec; SmallVector<Instruction*, 8> AllLoadsAndStores; SmallVector<Instruction*, 8> LocalLoadsAndStores; SmallVector<Instruction*, 8> AtomicAccesses; @@ -403,8 +423,6 @@ bool ThreadSanitizer::runOnFunction(Function &F) { AtomicAccesses.push_back(&Inst); else if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst)) LocalLoadsAndStores.push_back(&Inst); - else if (isa<ReturnInst>(Inst)) - RetVec.push_back(&Inst); else if (isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) { if (CallInst *CI = dyn_cast<CallInst>(&Inst)) maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI); @@ -440,6 +458,12 @@ bool ThreadSanitizer::runOnFunction(Function &F) { Res |= instrumentMemIntrinsic(Inst); } + if (F.hasFnAttribute("sanitize_thread_no_checking_at_run_time")) { + assert(!F.hasFnAttribute(Attribute::SanitizeThread)); + if (HasCalls) + InsertRuntimeIgnores(F); + } + // Instrument function entry/exit points if there were instrumented accesses. 
if ((Res || HasCalls) && ClInstrumentFuncEntryExit) { IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); @@ -447,9 +471,10 @@ bool ThreadSanitizer::runOnFunction(Function &F) { Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress), IRB.getInt32(0)); IRB.CreateCall(TsanFuncEntry, ReturnAddress); - for (auto RetInst : RetVec) { - IRBuilder<> IRBRet(RetInst); - IRBRet.CreateCall(TsanFuncExit, {}); + + EscapeEnumerator EE(F, "tsan_cleanup", ClHandleCxxExceptions); + while (IRBuilder<> *AtExit = EE.Next()) { + AtExit->CreateCall(TsanFuncExit, {}); } Res = true; } @@ -463,6 +488,13 @@ bool ThreadSanitizer::instrumentLoadOrStore(Instruction *I, Value *Addr = IsWrite ? cast<StoreInst>(I)->getPointerOperand() : cast<LoadInst>(I)->getPointerOperand(); + + // swifterror memory addresses are mem2reg promoted by instruction selection. + // As such they cannot have regular uses like an instrumentation function and + // it makes no sense to track them as memory. + if (Addr->isSwiftError()) + return false; + int Idx = getMemoryAccessFuncIndex(Addr, DL); if (Idx < 0) return false; @@ -511,7 +543,7 @@ static ConstantInt *createOrdering(IRBuilder<> *IRB, AtomicOrdering ord) { switch (ord) { case AtomicOrdering::NotAtomic: llvm_unreachable("unexpected atomic ordering!"); - case AtomicOrdering::Unordered: // Fall-through. + case AtomicOrdering::Unordered: LLVM_FALLTHROUGH; case AtomicOrdering::Monotonic: v = 0; break; // Not specified yet: // case AtomicOrdering::Consume: v = 1; break; @@ -551,11 +583,6 @@ bool ThreadSanitizer::instrumentMemIntrinsic(Instruction *I) { return false; } -static Value *createIntOrPtrToIntCast(Value *V, Type* Ty, IRBuilder<> &IRB) { - return isa<PointerType>(V->getType()) ? - IRB.CreatePtrToInt(V, Ty) : IRB.CreateIntCast(V, Ty, false); -} - // Both llvm and ThreadSanitizer atomic operations are based on C++11/C1x // standards. For background see C++11 standard. A slightly older, publicly // available draft of the standard (not entirely up-to-date, but close enough @@ -578,15 +605,9 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), createOrdering(&IRB, LI->getOrdering())}; Type *OrigTy = cast<PointerType>(Addr->getType())->getElementType(); - if (Ty == OrigTy) { - Instruction *C = CallInst::Create(TsanAtomicLoad[Idx], Args); - ReplaceInstWithInst(I, C); - } else { - // We are loading a pointer, so we need to cast the return value. 
- Value *C = IRB.CreateCall(TsanAtomicLoad[Idx], Args); - Instruction *Cast = CastInst::Create(Instruction::IntToPtr, C, OrigTy); - ReplaceInstWithInst(I, Cast); - } + Value *C = IRB.CreateCall(TsanAtomicLoad[Idx], Args); + Value *Cast = IRB.CreateBitOrPointerCast(C, OrigTy); + I->replaceAllUsesWith(Cast); } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { Value *Addr = SI->getPointerOperand(); int Idx = getMemoryAccessFuncIndex(Addr, DL); @@ -597,7 +618,7 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize); Type *PtrTy = Ty->getPointerTo(); Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), - createIntOrPtrToIntCast(SI->getValueOperand(), Ty, IRB), + IRB.CreateBitOrPointerCast(SI->getValueOperand(), Ty), createOrdering(&IRB, SI->getOrdering())}; CallInst *C = CallInst::Create(TsanAtomicStore[Idx], Args); ReplaceInstWithInst(I, C); @@ -628,9 +649,9 @@ bool ThreadSanitizer::instrumentAtomic(Instruction *I, const DataLayout &DL) { Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize); Type *PtrTy = Ty->getPointerTo(); Value *CmpOperand = - createIntOrPtrToIntCast(CASI->getCompareOperand(), Ty, IRB); + IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty); Value *NewOperand = - createIntOrPtrToIntCast(CASI->getNewValOperand(), Ty, IRB); + IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty); Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), CmpOperand, NewOperand, diff --git a/contrib/llvm/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h b/contrib/llvm/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h index d4fef10..c748272 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h +++ b/contrib/llvm/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h @@ -121,8 +121,7 @@ private: /// Declaration for objc_retainAutoreleaseReturnValue(). Constant *RetainAutoreleaseRV; - Constant *getVoidRetI8XEntryPoint(Constant *&Decl, - const char *Name) { + Constant *getVoidRetI8XEntryPoint(Constant *&Decl, StringRef Name) { if (Decl) return Decl; @@ -136,8 +135,7 @@ private: return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); } - Constant *getI8XRetI8XEntryPoint(Constant *& Decl, - const char *Name, + Constant *getI8XRetI8XEntryPoint(Constant *&Decl, StringRef Name, bool NoUnwind = false) { if (Decl) return Decl; @@ -155,8 +153,7 @@ private: return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); } - Constant *getI8XRetI8XXI8XEntryPoint(Constant *&Decl, - const char *Name) { + Constant *getI8XRetI8XXI8XEntryPoint(Constant *&Decl, StringRef Name) { if (Decl) return Decl; diff --git a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp index 11e2d03e..23c1f59 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -423,7 +423,7 @@ bool ObjCARCContract::tryToPeepholeInstruction( if (!optimizeRetainCall(F, Inst)) return false; // If we succeed in our optimization, fall through. - // FALLTHROUGH + LLVM_FALLTHROUGH; case ARCInstKind::RetainRV: case ARCInstKind::ClaimRV: { // If we're compiling for a target which needs a special inline-asm @@ -547,13 +547,13 @@ bool ObjCARCContract::runOnFunction(Function &F) { // Don't use GetArgRCIdentityRoot because we don't want to look through bitcasts // and such; to do the replacement, the argument must have type i8*. - Value *Arg = cast<CallInst>(Inst)->getArgOperand(0); - // TODO: Change this to a do-while. 
- for (;;) { + // Function for replacing uses of Arg dominated by Inst. + auto ReplaceArgUses = [Inst, this](Value *Arg) { // If we're compiling bugpointed code, don't get in trouble. if (!isa<Instruction>(Arg) && !isa<Argument>(Arg)) - break; + return; + // Look through the uses of the pointer. for (Value::use_iterator UI = Arg->use_begin(), UE = Arg->use_end(); UI != UE; ) { @@ -598,6 +598,15 @@ bool ObjCARCContract::runOnFunction(Function &F) { } } } + }; + + + Value *Arg = cast<CallInst>(Inst)->getArgOperand(0); + Value *OrigArg = Arg; + + // TODO: Change this to a do-while. + for (;;) { + ReplaceArgUses(Arg); // If Arg is a no-op casted pointer, strip one level of casts and iterate. if (const BitCastInst *BI = dyn_cast<BitCastInst>(Arg)) @@ -611,6 +620,24 @@ bool ObjCARCContract::runOnFunction(Function &F) { else break; } + + // Replace bitcast users of Arg that are dominated by Inst. + SmallVector<BitCastInst *, 2> BitCastUsers; + + // Add all bitcast users of the function argument first. + for (User *U : OrigArg->users()) + if (auto *BC = dyn_cast<BitCastInst>(U)) + BitCastUsers.push_back(BC); + + // Replace the bitcasts with the call return. Iterate until list is empty. + while (!BitCastUsers.empty()) { + auto *BC = BitCastUsers.pop_back_val(); + for (User *U : BC->users()) + if (auto *B = dyn_cast<BitCastInst>(U)) + BitCastUsers.push_back(B); + + ReplaceArgUses(BC); + } } // If this function has no escaping allocas or suspicious vararg usage, diff --git a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index a6907b5..136d54a 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/contrib/llvm/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -53,6 +53,11 @@ using namespace llvm::objcarc; /// \brief This is similar to GetRCIdentityRoot but it stops as soon /// as it finds a value with multiple uses. static const Value *FindSingleUseIdentifiedObject(const Value *Arg) { + // ConstantData (like ConstantPointerNull and UndefValue) is used across + // modules. It's never a single-use value. + if (isa<ConstantData>(Arg)) + return nullptr; + if (Arg->hasOneUse()) { if (const BitCastInst *BC = dyn_cast<BitCastInst>(Arg)) return FindSingleUseIdentifiedObject(BC->getOperand(0)); @@ -644,6 +649,12 @@ void ObjCARCOpt::OptimizeAutoreleaseRVCall(Function &F, ARCInstKind &Class) { // Check for a return of the pointer value. const Value *Ptr = GetArgRCIdentityRoot(AutoreleaseRV); + + // If the argument is ConstantPointerNull or UndefValue, its other users + // aren't actually interesting to look at. + if (isa<ConstantData>(Ptr)) + return; + SmallVector<const Value *, 2> Users; Users.push_back(Ptr); do { @@ -2075,12 +2086,11 @@ void ObjCARCOpt::OptimizeReturns(Function &F) { SmallPtrSet<const BasicBlock *, 4> Visited; for (BasicBlock &BB: F) { ReturnInst *Ret = dyn_cast<ReturnInst>(&BB.back()); - - DEBUG(dbgs() << "Visiting: " << *Ret << "\n"); - if (!Ret) continue; + DEBUG(dbgs() << "Visiting: " << *Ret << "\n"); + const Value *Arg = GetRCIdentityRoot(Ret->getOperand(0)); // Look for an ``autorelease'' instruction that is a predecessor of Ret and diff --git a/contrib/llvm/lib/Transforms/ObjCARC/PtrState.cpp b/contrib/llvm/lib/Transforms/ObjCARC/PtrState.cpp index df64fa3..a5afc8a 100644 --- a/contrib/llvm/lib/Transforms/ObjCARC/PtrState.cpp +++ b/contrib/llvm/lib/Transforms/ObjCARC/PtrState.cpp @@ -201,7 +201,7 @@ bool BottomUpPtrState::MatchWithRetain() { // imprecise release, clear our reverse insertion points. 
if (OldSeq != S_Use || IsTrackingImpreciseReleases()) ClearReverseInsertPts(); - // FALL THROUGH + LLVM_FALLTHROUGH; case S_CanRelease: return true; case S_None: @@ -332,7 +332,7 @@ bool TopDownPtrState::MatchWithRelease(ARCMDKindCache &Cache, case S_CanRelease: if (OldSeq == S_Retain || ReleaseMetadata != nullptr) ClearReverseInsertPts(); - // FALL THROUGH + LLVM_FALLTHROUGH; case S_Use: SetReleaseMetadata(ReleaseMetadata); SetTailCallRelease(cast<CallInst>(Release)->isTailCall()); diff --git a/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp b/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp index 0eed024..adc903c 100644 --- a/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp @@ -15,14 +15,19 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/ADCE.h" + #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -34,9 +39,372 @@ using namespace llvm; #define DEBUG_TYPE "adce" STATISTIC(NumRemoved, "Number of instructions removed"); +STATISTIC(NumBranchesRemoved, "Number of branch instructions removed"); + +// This is a tempoary option until we change the interface +// to this pass based on optimization level. +static cl::opt<bool> RemoveControlFlowFlag("adce-remove-control-flow", + cl::init(true), cl::Hidden); + +// This option enables removing of may-be-infinite loops which have no other +// effect. +static cl::opt<bool> RemoveLoops("adce-remove-loops", cl::init(false), + cl::Hidden); + +namespace { +/// Information about Instructions +struct InstInfoType { + /// True if the associated instruction is live. + bool Live = false; + /// Quick access to information for block containing associated Instruction. + struct BlockInfoType *Block = nullptr; +}; + +/// Information about basic blocks relevant to dead code elimination. +struct BlockInfoType { + /// True when this block contains a live instructions. + bool Live = false; + /// True when this block ends in an unconditional branch. + bool UnconditionalBranch = false; + /// True when this block is known to have live PHI nodes. + bool HasLivePhiNodes = false; + /// Control dependence sources need to be live for this block. + bool CFLive = false; + + /// Quick access to the LiveInfo for the terminator, + /// holds the value &InstInfo[Terminator] + InstInfoType *TerminatorLiveInfo = nullptr; + + bool terminatorIsLive() const { return TerminatorLiveInfo->Live; } + + /// Corresponding BasicBlock. + BasicBlock *BB = nullptr; + + /// Cache of BB->getTerminator(). + TerminatorInst *Terminator = nullptr; + + /// Post-order numbering of reverse control flow graph. + unsigned PostOrder; +}; + +class AggressiveDeadCodeElimination { + Function &F; + PostDominatorTree &PDT; + + /// Mapping of blocks to associated information, an element in BlockInfoVec. + DenseMap<BasicBlock *, BlockInfoType> BlockInfo; + bool isLive(BasicBlock *BB) { return BlockInfo[BB].Live; } + + /// Mapping of instructions to associated information. 
+  DenseMap<Instruction *, InstInfoType> InstInfo;
+  bool isLive(Instruction *I) { return InstInfo[I].Live; }
+
+  /// Instructions known to be live where we need to mark
+  /// reaching definitions as live.
+  SmallVector<Instruction *, 128> Worklist;
+  /// Debug info scopes around a live instruction.
+  SmallPtrSet<const Metadata *, 32> AliveScopes;
+
+  /// Set of blocks not known to have live terminators.
+  SmallPtrSet<BasicBlock *, 16> BlocksWithDeadTerminators;
+
+  /// The set of blocks which we have determined whose control
+  /// dependence sources must be live and which have not had
+  /// those dependences analyzed.
+  SmallPtrSet<BasicBlock *, 16> NewLiveBlocks;
+
+  /// Set up auxiliary data structures for Instructions and BasicBlocks and
+  /// initialize the Worklist to the set of must-be-live Instructions.
+  void initialize();
+  /// Return true for operations which are always treated as live.
+  bool isAlwaysLive(Instruction &I);
+  /// Return true for instrumentation instructions for value profiling.
+  bool isInstrumentsConstant(Instruction &I);
+
+  /// Propagate liveness to reaching definitions.
+  void markLiveInstructions();
+  /// Mark an instruction as live.
+  void markLive(Instruction *I);
+  /// Mark a block as live.
+  void markLive(BlockInfoType &BB);
+  void markLive(BasicBlock *BB) { markLive(BlockInfo[BB]); }
+
+  /// Mark terminators of control predecessors of a PHI node live.
+  void markPhiLive(PHINode *PN);
+
+  /// Record the Debug Scopes which surround live debug information.
+  void collectLiveScopes(const DILocalScope &LS);
+  void collectLiveScopes(const DILocation &DL);
+
+  /// Analyze dead branches to find those whose branches are the sources
+  /// of control dependences impacting a live block. Those branches are
+  /// marked live.
+  void markLiveBranchesFromControlDependences();
+
+  /// Remove instructions not marked live; return true if any instruction
+  /// was removed.
+  bool removeDeadInstructions();
+
+  /// Identify connected sections of the control flow graph which have
+  /// dead terminators and rewrite the control flow graph to remove them.
+  void updateDeadRegions();
+
+  /// Set the BlockInfo::PostOrder field based on a post-order
+  /// numbering of the reverse control flow graph.
+  void computeReversePostOrder();
+
+  /// Make the terminator of this block an unconditional branch to \p Target.
+  void makeUnconditional(BasicBlock *BB, BasicBlock *Target);
+
+public:
+  AggressiveDeadCodeElimination(Function &F, PostDominatorTree &PDT)
+      : F(F), PDT(PDT) {}
+  bool performDeadCodeElimination();
+};
+}
+
+bool AggressiveDeadCodeElimination::performDeadCodeElimination() {
+  initialize();
+  markLiveInstructions();
+  return removeDeadInstructions();
+}
+
+static bool isUnconditionalBranch(TerminatorInst *Term) {
+  auto *BR = dyn_cast<BranchInst>(Term);
+  return BR && BR->isUnconditional();
+}
+
+void AggressiveDeadCodeElimination::initialize() {
+
+  auto NumBlocks = F.size();
+
+  // We will have an entry in the map for each block, so we grow the
+  // structure to twice that size to keep the load factor low in the hash table.
+  BlockInfo.reserve(NumBlocks);
+  size_t NumInsts = 0;
+
+  // Iterate over blocks and initialize BlockInfoVec entries, count
+  // instructions to size the InstInfo hash table.
+  for (auto &BB : F) {
+    NumInsts += BB.size();
+    auto &Info = BlockInfo[&BB];
+    Info.BB = &BB;
+    Info.Terminator = BB.getTerminator();
+    Info.UnconditionalBranch = isUnconditionalBranch(Info.Terminator);
+  }
+
+  // Initialize the instruction map and set pointers to block info.
+  InstInfo.reserve(NumInsts);
+  for (auto &BBInfo : BlockInfo)
+    for (Instruction &I : *BBInfo.second.BB)
+      InstInfo[&I].Block = &BBInfo.second;
+
+  // Since BlockInfoVec holds pointers into InstInfo and vice-versa, we may not
+  // add any more elements to either after this point.
+  for (auto &BBInfo : BlockInfo)
+    BBInfo.second.TerminatorLiveInfo = &InstInfo[BBInfo.second.Terminator];
+
+  // Collect the set of "root" instructions that are known live.
+  for (Instruction &I : instructions(F))
+    if (isAlwaysLive(I))
+      markLive(&I);
+
+  if (!RemoveControlFlowFlag)
+    return;
+
+  if (!RemoveLoops) {
+    // This stores state for the depth-first iterator. In addition
+    // to recording which nodes have been visited we also record whether
+    // a node is currently on the "stack" of active ancestors of the current
+    // node.
+    typedef DenseMap<BasicBlock *, bool> StatusMap;
+    class DFState : public StatusMap {
+    public:
+      std::pair<StatusMap::iterator, bool> insert(BasicBlock *BB) {
+        return StatusMap::insert(std::make_pair(BB, true));
+      }
+
+      // Invoked after we have visited all children of a node.
+      void completed(BasicBlock *BB) { (*this)[BB] = false; }
+
+      // Return true if \p BB is currently on the active stack
+      // of ancestors.
+      bool onStack(BasicBlock *BB) {
+        auto Iter = find(BB);
+        return Iter != end() && Iter->second;
+      }
+    } State;
+
+    State.reserve(F.size());
+    // Iterate over blocks in depth-first pre-order and
+    // treat all edges to a block already seen as loop back edges,
+    // marking the branch live if there is a back edge.
+    for (auto *BB : depth_first_ext(&F.getEntryBlock(), State)) {
+      TerminatorInst *Term = BB->getTerminator();
+      if (isLive(Term))
+        continue;
+
+      for (auto *Succ : successors(BB))
+        if (State.onStack(Succ)) {
+          // back edge....
+          markLive(Term);
+          break;
+        }
+    }
+  }
+
+  // Mark blocks live if there is no path from the block to the
+  // return of the function or a successor for which this is true.
+  // This protects IDFCalculator which cannot handle such blocks.
+ for (auto &BBInfoPair : BlockInfo) { + auto &BBInfo = BBInfoPair.second; + if (BBInfo.terminatorIsLive()) + continue; + auto *BB = BBInfo.BB; + if (!PDT.getNode(BB)) { + DEBUG(dbgs() << "Not post-dominated by return: " << BB->getName() + << '\n';); + markLive(BBInfo.Terminator); + continue; + } + for (auto *Succ : successors(BB)) + if (!PDT.getNode(Succ)) { + DEBUG(dbgs() << "Successor not post-dominated by return: " + << BB->getName() << '\n';); + markLive(BBInfo.Terminator); + break; + } + } + + // Treat the entry block as always live + auto *BB = &F.getEntryBlock(); + auto &EntryInfo = BlockInfo[BB]; + EntryInfo.Live = true; + if (EntryInfo.UnconditionalBranch) + markLive(EntryInfo.Terminator); + + // Build initial collection of blocks with dead terminators + for (auto &BBInfo : BlockInfo) + if (!BBInfo.second.terminatorIsLive()) + BlocksWithDeadTerminators.insert(BBInfo.second.BB); +} + +bool AggressiveDeadCodeElimination::isAlwaysLive(Instruction &I) { + // TODO -- use llvm::isInstructionTriviallyDead + if (I.isEHPad() || I.mayHaveSideEffects()) { + // Skip any value profile instrumentation calls if they are + // instrumenting constants. + if (isInstrumentsConstant(I)) + return false; + return true; + } + if (!isa<TerminatorInst>(I)) + return false; + if (RemoveControlFlowFlag && (isa<BranchInst>(I) || isa<SwitchInst>(I))) + return false; + return true; +} + +// Check if this instruction is a runtime call for value profiling and +// if it's instrumenting a constant. +bool AggressiveDeadCodeElimination::isInstrumentsConstant(Instruction &I) { + // TODO -- move this test into llvm::isInstructionTriviallyDead + if (CallInst *CI = dyn_cast<CallInst>(&I)) + if (Function *Callee = CI->getCalledFunction()) + if (Callee->getName().equals(getInstrProfValueProfFuncName())) + if (isa<Constant>(CI->getArgOperand(0))) + return true; + return false; +} + +void AggressiveDeadCodeElimination::markLiveInstructions() { + + // Propagate liveness backwards to operands. + do { + // Worklist holds newly discovered live instructions + // where we need to mark the inputs as live. + while (!Worklist.empty()) { + Instruction *LiveInst = Worklist.pop_back_val(); + DEBUG(dbgs() << "work live: "; LiveInst->dump();); + + for (Use &OI : LiveInst->operands()) + if (Instruction *Inst = dyn_cast<Instruction>(OI)) + markLive(Inst); + + if (auto *PN = dyn_cast<PHINode>(LiveInst)) + markPhiLive(PN); + } -static void collectLiveScopes(const DILocalScope &LS, - SmallPtrSetImpl<const Metadata *> &AliveScopes) { + // After data flow liveness has been identified, examine which branch + // decisions are required to determine live instructions are executed. + markLiveBranchesFromControlDependences(); + + } while (!Worklist.empty()); +} + +void AggressiveDeadCodeElimination::markLive(Instruction *I) { + + auto &Info = InstInfo[I]; + if (Info.Live) + return; + + DEBUG(dbgs() << "mark live: "; I->dump()); + Info.Live = true; + Worklist.push_back(I); + + // Collect the live debug info scopes attached to this instruction. + if (const DILocation *DL = I->getDebugLoc()) + collectLiveScopes(*DL); + + // Mark the containing block live + auto &BBInfo = *Info.Block; + if (BBInfo.Terminator == I) { + BlocksWithDeadTerminators.erase(BBInfo.BB); + // For live terminators, mark destination blocks + // live to preserve this control flow edges. 
+ if (!BBInfo.UnconditionalBranch) + for (auto *BB : successors(I->getParent())) + markLive(BB); + } + markLive(BBInfo); +} + +void AggressiveDeadCodeElimination::markLive(BlockInfoType &BBInfo) { + if (BBInfo.Live) + return; + DEBUG(dbgs() << "mark block live: " << BBInfo.BB->getName() << '\n'); + BBInfo.Live = true; + if (!BBInfo.CFLive) { + BBInfo.CFLive = true; + NewLiveBlocks.insert(BBInfo.BB); + } + + // Mark unconditional branches at the end of live + // blocks as live since there is no work to do for them later + if (BBInfo.UnconditionalBranch) + markLive(BBInfo.Terminator); +} + +void AggressiveDeadCodeElimination::collectLiveScopes(const DILocalScope &LS) { if (!AliveScopes.insert(&LS).second) return; @@ -44,75 +412,115 @@ static void collectLiveScopes(const DILocalScope &LS, return; // Tail-recurse through the scope chain. - collectLiveScopes(cast<DILocalScope>(*LS.getScope()), AliveScopes); + collectLiveScopes(cast<DILocalScope>(*LS.getScope())); } -static void collectLiveScopes(const DILocation &DL, - SmallPtrSetImpl<const Metadata *> &AliveScopes) { +void AggressiveDeadCodeElimination::collectLiveScopes(const DILocation &DL) { // Even though DILocations are not scopes, shove them into AliveScopes so we // don't revisit them. if (!AliveScopes.insert(&DL).second) return; // Collect live scopes from the scope chain. - collectLiveScopes(*DL.getScope(), AliveScopes); + collectLiveScopes(*DL.getScope()); // Tail-recurse through the inlined-at chain. if (const DILocation *IA = DL.getInlinedAt()) - collectLiveScopes(*IA, AliveScopes); + collectLiveScopes(*IA); } -// Check if this instruction is a runtime call for value profiling and -// if it's instrumenting a constant. -static bool isInstrumentsConstant(Instruction &I) { - if (CallInst *CI = dyn_cast<CallInst>(&I)) - if (Function *Callee = CI->getCalledFunction()) - if (Callee->getName().equals(getInstrProfValueProfFuncName())) - if (isa<Constant>(CI->getArgOperand(0))) - return true; - return false; +void AggressiveDeadCodeElimination::markPhiLive(PHINode *PN) { + auto &Info = BlockInfo[PN->getParent()]; + // Only need to check this once per block. + if (Info.HasLivePhiNodes) + return; + Info.HasLivePhiNodes = true; + + // If a predecessor block is not live, mark it as control-flow live + // which will trigger marking live branches upon which + // that block is control dependent. + for (auto *PredBB : predecessors(Info.BB)) { + auto &Info = BlockInfo[PredBB]; + if (!Info.CFLive) { + Info.CFLive = true; + NewLiveBlocks.insert(PredBB); + } + } } -static bool aggressiveDCE(Function& F) { - SmallPtrSet<Instruction*, 32> Alive; - SmallVector<Instruction*, 128> Worklist; +void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() { - // Collect the set of "root" instructions that are known live. - for (Instruction &I : instructions(F)) { - if (isa<TerminatorInst>(I) || I.isEHPad() || I.mayHaveSideEffects()) { - // Skip any value profile instrumentation calls if they are - // instrumenting constants. 
- if (isInstrumentsConstant(I)) - continue; - Alive.insert(&I); - Worklist.push_back(&I); - } + if (BlocksWithDeadTerminators.empty()) + return; + + DEBUG({ + dbgs() << "new live blocks:\n"; + for (auto *BB : NewLiveBlocks) + dbgs() << "\t" << BB->getName() << '\n'; + dbgs() << "dead terminator blocks:\n"; + for (auto *BB : BlocksWithDeadTerminators) + dbgs() << "\t" << BB->getName() << '\n'; + }); + + // The dominance frontier of a live block X in the reverse + // control graph is the set of blocks upon which X is control + // dependent. The following sequence computes the set of blocks + // which currently have dead terminators that are control + // dependence sources of a block which is in NewLiveBlocks. + + SmallVector<BasicBlock *, 32> IDFBlocks; + ReverseIDFCalculator IDFs(PDT); + IDFs.setDefiningBlocks(NewLiveBlocks); + IDFs.setLiveInBlocks(BlocksWithDeadTerminators); + IDFs.calculate(IDFBlocks); + NewLiveBlocks.clear(); + + // Dead terminators which control live blocks are now marked live. + for (auto *BB : IDFBlocks) { + DEBUG(dbgs() << "live control in: " << BB->getName() << '\n'); + markLive(BB->getTerminator()); } +} - // Propagate liveness backwards to operands. Keep track of live debug info - // scopes. - SmallPtrSet<const Metadata *, 32> AliveScopes; - while (!Worklist.empty()) { - Instruction *Curr = Worklist.pop_back_val(); +//===----------------------------------------------------------------------===// +// +// Routines to update the CFG and SSA information before removing dead code. +// +//===----------------------------------------------------------------------===// +bool AggressiveDeadCodeElimination::removeDeadInstructions() { - // Collect the live debug info scopes attached to this instruction. - if (const DILocation *DL = Curr->getDebugLoc()) - collectLiveScopes(*DL, AliveScopes); + // Updates control and dataflow around dead blocks + updateDeadRegions(); - for (Use &OI : Curr->operands()) { - if (Instruction *Inst = dyn_cast<Instruction>(OI)) - if (Alive.insert(Inst).second) - Worklist.push_back(Inst); + DEBUG({ + for (Instruction &I : instructions(F)) { + // Check if the instruction is alive. + if (isLive(&I)) + continue; + + if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&I)) { + // Check if the scope of this variable location is alive. + if (AliveScopes.count(DII->getDebugLoc()->getScope())) + continue; + + // If intrinsic is pointing at a live SSA value, there may be an + // earlier optimization bug: if we know the location of the variable, + // why isn't the scope of the location alive? + if (Value *V = DII->getVariableLocation()) + if (Instruction *II = dyn_cast<Instruction>(V)) + if (isLive(II)) + dbgs() << "Dropping debug info for " << *DII << "\n"; + } } - } + }); // The inverse of the live set is the dead set. These are those instructions - // which have no side effects and do not influence the control flow or return + // that have no side effects and do not influence the control flow or return // value of the function, and may therefore be deleted safely. // NOTE: We reuse the Worklist vector here for memory efficiency. for (Instruction &I : instructions(F)) { // Check if the instruction is alive. - if (Alive.count(&I)) + if (isLive(&I)) continue; if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&I)) { @@ -121,15 +529,6 @@ static bool aggressiveDCE(Function& F) { continue; // Fallthrough and drop the intrinsic. 
-      DEBUG({
-        // If intrinsic is pointing at a live SSA value, there may be an
-        // earlier optimization bug: if we know the location of the variable,
-        // why isn't the scope of the location alive?
-        if (Value *V = DII->getVariableLocation())
-          if (Instruction *II = dyn_cast<Instruction>(V))
-            if (Alive.count(II))
-              dbgs() << "Dropping debug info for " << *DII << "\n";
-      });
     }
 
     // Prepare to delete.
@@ -145,8 +544,104 @@ static bool aggressiveDCE(Function& F) {
   return !Worklist.empty();
 }
 
-PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &) {
-  if (!aggressiveDCE(F))
+// A dead region is the set of dead blocks with a common live post-dominator.
+void AggressiveDeadCodeElimination::updateDeadRegions() {
+
+  DEBUG({
+    dbgs() << "final dead terminator blocks: " << '\n';
+    for (auto *BB : BlocksWithDeadTerminators)
+      dbgs() << '\t' << BB->getName()
+             << (BlockInfo[BB].Live ? " LIVE\n" : "\n");
+  });
+
+  // Don't compute the post ordering unless we needed it.
+  bool HavePostOrder = false;
+
+  for (auto *BB : BlocksWithDeadTerminators) {
+    auto &Info = BlockInfo[BB];
+    if (Info.UnconditionalBranch) {
+      InstInfo[Info.Terminator].Live = true;
+      continue;
+    }
+
+    if (!HavePostOrder) {
+      computeReversePostOrder();
+      HavePostOrder = true;
+    }
+
+    // Add an unconditional branch to the successor closest to the
+    // end of the function, which ensures a path to the exit for each
+    // live edge.
+    BlockInfoType *PreferredSucc = nullptr;
+    for (auto *Succ : successors(BB)) {
+      auto *Info = &BlockInfo[Succ];
+      if (!PreferredSucc || PreferredSucc->PostOrder < Info->PostOrder)
+        PreferredSucc = Info;
+    }
+    assert((PreferredSucc && PreferredSucc->PostOrder > 0) &&
+           "Failed to find safe successor for dead branch");
+    bool First = true;
+    for (auto *Succ : successors(BB)) {
+      if (!First || Succ != PreferredSucc->BB)
+        Succ->removePredecessor(BB);
+      else
+        First = false;
+    }
+    makeUnconditional(BB, PreferredSucc->BB);
+    NumBranchesRemoved += 1;
+  }
+}
+
+// reverse top-sort order
+void AggressiveDeadCodeElimination::computeReversePostOrder() {
+
+  // This provides a post-order numbering of the reverse control flow graph.
+  // Note that it is incomplete in the presence of infinite loops, but we don't
+  // need to number blocks which don't reach the end of the function since
+  // all branches in those blocks are forced live.
+
+  // For each block without successors, extend the DFS from the block
+  // backward through the graph.
+  SmallPtrSet<BasicBlock*, 16> Visited;
+  unsigned PostOrder = 0;
+  for (auto &BB : F) {
+    if (succ_begin(&BB) != succ_end(&BB))
+      continue;
+    for (BasicBlock *Block : inverse_post_order_ext(&BB, Visited))
+      BlockInfo[Block].PostOrder = PostOrder++;
+  }
+}
+
+void AggressiveDeadCodeElimination::makeUnconditional(BasicBlock *BB,
+                                                      BasicBlock *Target) {
+  TerminatorInst *PredTerm = BB->getTerminator();
+  // Collect the live debug info scopes attached to this instruction.
+ if (const DILocation *DL = PredTerm->getDebugLoc()) + collectLiveScopes(*DL); + + // Just mark live an existing unconditional branch + if (isUnconditionalBranch(PredTerm)) { + PredTerm->setSuccessor(0, Target); + InstInfo[PredTerm].Live = true; + return; + } + DEBUG(dbgs() << "making unconditional " << BB->getName() << '\n'); + NumBranchesRemoved += 1; + IRBuilder<> Builder(PredTerm); + auto *NewTerm = Builder.CreateBr(Target); + InstInfo[NewTerm].Live = true; + if (const DILocation *DL = PredTerm->getDebugLoc()) + NewTerm->setDebugLoc(DL); +} + +//===----------------------------------------------------------------------===// +// +// Pass Manager integration code +// +//===----------------------------------------------------------------------===// +PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) { + auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F); + if (!AggressiveDeadCodeElimination(F, PDT).performDeadCodeElimination()) return PreservedAnalyses::all(); // FIXME: This should also 'preserve the CFG'. @@ -162,21 +657,27 @@ struct ADCELegacyPass : public FunctionPass { initializeADCELegacyPassPass(*PassRegistry::getPassRegistry()); } - bool runOnFunction(Function& F) override { + bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; - return aggressiveDCE(F); + auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + return AggressiveDeadCodeElimination(F, PDT).performDeadCodeElimination(); } - void getAnalysisUsage(AnalysisUsage& AU) const override { - AU.setPreservesCFG(); + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<PostDominatorTreeWrapperPass>(); + if (!RemoveControlFlowFlag) + AU.setPreservesCFG(); AU.addPreserved<GlobalsAAWrapperPass>(); } }; } char ADCELegacyPass::ID = 0; -INITIALIZE_PASS(ADCELegacyPass, "adce", "Aggressive Dead Code Elimination", - false, false) +INITIALIZE_PASS_BEGIN(ADCELegacyPass, "adce", + "Aggressive Dead Code Elimination", false, false) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_END(ADCELegacyPass, "adce", "Aggressive Dead Code Elimination", + false, false) FunctionPass *llvm::createAggressiveDCEPass() { return new ADCELegacyPass(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp index 7f8b8ce..c1df317 100644 --- a/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp @@ -297,6 +297,11 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV)) return false; + // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't + // affect other users. + if (isa<ConstantData>(AAPtr)) + return false; + const SCEV *AASCEV = SE->getSCEV(AAPtr); // Apply the assumption to all other users of the specified pointer. @@ -434,6 +439,11 @@ AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) { ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F); DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); bool Changed = runImpl(F, AC, &SE, &DT); + + // FIXME: We need to invalidate this to avoid PR28400. Is there a better + // solution? 
+ AM.invalidate<ScalarEvolutionAnalysis>(F); + if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp b/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp index 4f6225f..251b387 100644 --- a/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp @@ -39,6 +39,12 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { SmallVector<Instruction*, 128> Worklist; bool Changed = false; for (Instruction &I : instructions(F)) { + // If the instruction has side effects and no non-dbg uses, + // skip it. This way we avoid computing known bits on an instruction + // that will not help us. + if (I.mayHaveSideEffects() && I.use_empty()) + continue; + if (I.getType()->isIntegerTy() && !DB.getDemandedBits(&I).getBoolValue()) { // For live instructions that have all dead bits, first make them dead by @@ -50,7 +56,7 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { // undef, poison, etc. Value *Zero = ConstantInt::get(I.getType(), 0); ++NumSimplified; - I.replaceAllUsesWith(Zero); + I.replaceNonMetadataUsesWith(Zero); Changed = true; } if (!DB.isInstructionDead(&I)) diff --git a/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp index 913e939..3826251 100644 --- a/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -64,7 +64,7 @@ public: bool runOnFunction(Function &Fn) override; - const char *getPassName() const override { return "Constant Hoisting"; } + StringRef getPassName() const override { return "Constant Hoisting"; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); @@ -444,7 +444,7 @@ void ConstantHoistingPass::findBaseConstants() { /// \brief Updates the operand at Idx in instruction Inst with the result of /// instruction Mat. If the instruction is a PHI node then special -/// handling for duplicate values form the same incomming basic block is +/// handling for duplicate values form the same incoming basic block is /// required. /// \return The update will always succeed, but the return value indicated if /// Mat was used for the update or not. 
diff --git a/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index c0fed05..84f9373 100644 --- a/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -18,6 +18,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -37,8 +38,11 @@ STATISTIC(NumCmps, "Number of comparisons propagated"); STATISTIC(NumReturns, "Number of return values propagated"); STATISTIC(NumDeadCases, "Number of switch cases removed"); STATISTIC(NumSDivs, "Number of sdiv converted to udiv"); +STATISTIC(NumAShrs, "Number of ashr converted to lshr"); STATISTIC(NumSRems, "Number of srem converted to urem"); +static cl::opt<bool> DontProcessAdds("cvp-dont-process-adds", cl::init(true)); + namespace { class CorrelatedValuePropagation : public FunctionPass { public: @@ -381,6 +385,81 @@ static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) { return true; } +static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { + if (SDI->getType()->isVectorTy() || hasLocalDefs(SDI)) + return false; + + Constant *Zero = ConstantInt::get(SDI->getType(), 0); + if (LVI->getPredicateAt(ICmpInst::ICMP_SGE, SDI->getOperand(0), Zero, SDI) != + LazyValueInfo::True) + return false; + + ++NumAShrs; + auto *BO = BinaryOperator::CreateLShr(SDI->getOperand(0), SDI->getOperand(1), + SDI->getName(), SDI); + BO->setIsExact(SDI->isExact()); + SDI->replaceAllUsesWith(BO); + SDI->eraseFromParent(); + + return true; +} + +static bool processAdd(BinaryOperator *AddOp, LazyValueInfo *LVI) { + typedef OverflowingBinaryOperator OBO; + + if (DontProcessAdds) + return false; + + if (AddOp->getType()->isVectorTy() || hasLocalDefs(AddOp)) + return false; + + bool NSW = AddOp->hasNoSignedWrap(); + bool NUW = AddOp->hasNoUnsignedWrap(); + if (NSW && NUW) + return false; + + BasicBlock *BB = AddOp->getParent(); + + Value *LHS = AddOp->getOperand(0); + Value *RHS = AddOp->getOperand(1); + + ConstantRange LRange = LVI->getConstantRange(LHS, BB, AddOp); + + // Initialize RRange only if we need it. If we know that guaranteed no wrap + // range for the given LHS range is empty don't spend time calculating the + // range for the RHS. 
+ Optional<ConstantRange> RRange; + auto LazyRRange = [&] () { + if (!RRange) + RRange = LVI->getConstantRange(RHS, BB, AddOp); + return RRange.getValue(); + }; + + bool Changed = false; + if (!NUW) { + ConstantRange NUWRange = + LRange.makeGuaranteedNoWrapRegion(BinaryOperator::Add, LRange, + OBO::NoUnsignedWrap); + if (!NUWRange.isEmptySet()) { + bool NewNUW = NUWRange.contains(LazyRRange()); + AddOp->setHasNoUnsignedWrap(NewNUW); + Changed |= NewNUW; + } + } + if (!NSW) { + ConstantRange NSWRange = + LRange.makeGuaranteedNoWrapRegion(BinaryOperator::Add, LRange, + OBO::NoSignedWrap); + if (!NSWRange.isEmptySet()) { + bool NewNSW = NSWRange.contains(LazyRRange()); + AddOp->setHasNoSignedWrap(NewNSW); + Changed |= NewNSW; + } + } + + return Changed; +} + static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) { if (Constant *C = LVI->getConstant(V, At->getParent(), At)) return C; @@ -407,9 +486,14 @@ static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) { static bool runImpl(Function &F, LazyValueInfo *LVI) { bool FnChanged = false; - for (BasicBlock &BB : F) { + // Visiting in a pre-order depth-first traversal causes us to simplify early + // blocks before querying later blocks (which require us to analyze early + // blocks). Eagerly simplifying shallow blocks means there is strictly less + // work to do for deep blocks. This also means we don't visit unreachable + // blocks. + for (BasicBlock *BB : depth_first(&F.getEntryBlock())) { bool BBChanged = false; - for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) { + for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { Instruction *II = &*BI++; switch (II->getOpcode()) { case Instruction::Select: @@ -436,10 +520,16 @@ static bool runImpl(Function &F, LazyValueInfo *LVI) { case Instruction::SDiv: BBChanged |= processSDiv(cast<BinaryOperator>(II), LVI); break; + case Instruction::AShr: + BBChanged |= processAShr(cast<BinaryOperator>(II), LVI); + break; + case Instruction::Add: + BBChanged |= processAdd(cast<BinaryOperator>(II), LVI); + break; } } - Instruction *Term = BB.getTerminator(); + Instruction *Term = BB->getTerminator(); switch (Term->getOpcode()) { case Instruction::Switch: BBChanged |= processSwitch(cast<SwitchInst>(Term), LVI); diff --git a/contrib/llvm/lib/Transforms/Scalar/DCE.cpp b/contrib/llvm/lib/Transforms/Scalar/DCE.cpp index f73809d..cc2a3cf 100644 --- a/contrib/llvm/lib/Transforms/Scalar/DCE.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/DCE.cpp @@ -123,7 +123,7 @@ static bool eliminateDeadCode(Function &F, TargetLibraryInfo *TLI) { return MadeChange; } -PreservedAnalyses DCEPass::run(Function &F, AnalysisManager<Function> &AM) { +PreservedAnalyses DCEPass::run(Function &F, FunctionAnalysisManager &AM) { if (eliminateDeadCode(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) return PreservedAnalyses::none(); return PreservedAnalyses::all(); diff --git a/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp index ed58a87..4d4c3ba 100644 --- a/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -59,6 +59,8 @@ EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking", //===----------------------------------------------------------------------===// // Helper functions //===----------------------------------------------------------------------===// +typedef std::map<int64_t, int64_t> 
OverlapIntervalsTy; +typedef DenseMap<Instruction *, OverlapIntervalsTy> InstOverlapIntervalsTy; /// Delete this instruction. Before we do, go through and zero out all the /// operands of this instruction. If any of them become dead, delete them and @@ -67,6 +69,8 @@ EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking", static void deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, MemoryDependenceResults &MD, const TargetLibraryInfo &TLI, + InstOverlapIntervalsTy &IOL, + DenseMap<Instruction*, size_t> *InstrOrdering, SmallSetVector<Value *, 16> *ValueSet = nullptr) { SmallVector<Instruction*, 32> NowDeadInsts; @@ -99,13 +103,14 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, NowDeadInsts.push_back(OpI); } + if (ValueSet) ValueSet->remove(DeadInst); + InstrOrdering->erase(DeadInst); + IOL.erase(DeadInst); if (NewIter == DeadInst->getIterator()) NewIter = DeadInst->eraseFromParent(); else DeadInst->eraseFromParent(); - - if (ValueSet) ValueSet->remove(DeadInst); } while (!NowDeadInsts.empty()); *BBI = NewIter; } @@ -290,9 +295,6 @@ enum OverwriteResult { }; } -typedef DenseMap<Instruction *, - std::map<int64_t, int64_t>> InstOverlapIntervalsTy; - /// Return 'OverwriteComplete' if a store to the 'Later' location completely /// overwrites a store to the 'Earlier' location, 'OverwriteEnd' if the end of /// the 'Earlier' location is completely overwritten by 'Later', @@ -438,9 +440,9 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // // In this case we may want to trim the size of earlier to avoid generating // writes to addresses which will definitely be overwritten later - if (LaterOff > EarlierOff && - LaterOff < int64_t(EarlierOff + Earlier.Size) && - int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size)) + if (!EnablePartialOverwriteTracking && + (LaterOff > EarlierOff && LaterOff < int64_t(EarlierOff + Earlier.Size) && + int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size))) return OverwriteEnd; // Finally, we also need to check if the later store overwrites the beginning @@ -452,9 +454,11 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // In this case we may want to move the destination address and trim the size // of earlier to avoid generating writes to addresses which will definitely // be overwritten later. - if (LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff) { - assert (int64_t(LaterOff + Later.Size) < int64_t(EarlierOff + Earlier.Size) - && "Expect to be handled as OverwriteComplete" ); + if (!EnablePartialOverwriteTracking && + (LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff)) { + assert(int64_t(LaterOff + Later.Size) < + int64_t(EarlierOff + Earlier.Size) && + "Expect to be handled as OverwriteComplete"); return OverwriteBegin; } // Otherwise, they don't completely overlap. @@ -505,7 +509,6 @@ static bool isPossibleSelfRead(Instruction *Inst, return true; } - /// Returns true if the memory which is accessed by the second instruction is not /// modified between the first and the second instruction. /// Precondition: Second instruction must be dominated by the first @@ -585,7 +588,9 @@ static void findUnconditionalPreds(SmallVectorImpl<BasicBlock *> &Blocks, /// to a field of that structure. 
static bool handleFree(CallInst *F, AliasAnalysis *AA, MemoryDependenceResults *MD, DominatorTree *DT, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + InstOverlapIntervalsTy &IOL, + DenseMap<Instruction*, size_t> *InstrOrdering) { bool MadeChange = false; MemoryLocation Loc = MemoryLocation(F->getOperand(0)); @@ -612,9 +617,12 @@ static bool handleFree(CallInst *F, AliasAnalysis *AA, if (!AA->isMustAlias(F->getArgOperand(0), DepPointer)) break; + DEBUG(dbgs() << "DSE: Dead Store to soon to be freed memory:\n DEAD: " + << *Dependency << '\n'); + // DCE instructions only used to calculate that store. BasicBlock::iterator BBI(Dependency); - deleteDeadInstruction(Dependency, &BBI, *MD, *TLI); + deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL, InstrOrdering); ++NumFastStores; MadeChange = true; @@ -669,7 +677,9 @@ static void removeAccessedObjects(const MemoryLocation &LoadedLoc, /// ret void static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, MemoryDependenceResults *MD, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + InstOverlapIntervalsTy &IOL, + DenseMap<Instruction*, size_t> *InstrOrdering) { bool MadeChange = false; // Keep track of all of the stack objects that are dead at the end of the @@ -728,7 +738,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, dbgs() << '\n'); // DCE instructions only used to calculate that store. - deleteDeadInstruction(Dead, &BBI, *MD, *TLI, &DeadStackObjects); + deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, InstrOrdering, &DeadStackObjects); ++NumFastStores; MadeChange = true; continue; @@ -737,7 +747,9 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, // Remove any dead non-memory-mutating instructions. if (isInstructionTriviallyDead(&*BBI, TLI)) { - deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, &DeadStackObjects); + DEBUG(dbgs() << "DSE: Removing trivially dead instruction:\n DEAD: " + << *&*BBI << '\n'); + deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, InstrOrdering, &DeadStackObjects); ++NumFastOther; MadeChange = true; continue; @@ -819,10 +831,125 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, return MadeChange; } +static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierOffset, + int64_t &EarlierSize, int64_t LaterOffset, + int64_t LaterSize, bool IsOverwriteEnd) { + // TODO: base this on the target vector size so that if the earlier + // store was too small to get vector writes anyway then its likely + // a good idea to shorten it + // Power of 2 vector writes are probably always a bad idea to optimize + // as any store/memset/memcpy is likely using vector instructions so + // shortening it to not vector size is likely to be slower + MemIntrinsic *EarlierIntrinsic = cast<MemIntrinsic>(EarlierWrite); + unsigned EarlierWriteAlign = EarlierIntrinsic->getAlignment(); + if (!IsOverwriteEnd) + LaterOffset = int64_t(LaterOffset + LaterSize); + + if (!(llvm::isPowerOf2_64(LaterOffset) && EarlierWriteAlign <= LaterOffset) && + !((EarlierWriteAlign != 0) && LaterOffset % EarlierWriteAlign == 0)) + return false; + + DEBUG(dbgs() << "DSE: Remove Dead Store:\n OW " + << (IsOverwriteEnd ? "END" : "BEGIN") << ": " << *EarlierWrite + << "\n KILLER (offset " << LaterOffset << ", " << EarlierSize + << ")\n"); + + int64_t NewLength = IsOverwriteEnd + ? 
LaterOffset - EarlierOffset + : EarlierSize - (LaterOffset - EarlierOffset); + + Value *EarlierWriteLength = EarlierIntrinsic->getLength(); + Value *TrimmedLength = + ConstantInt::get(EarlierWriteLength->getType(), NewLength); + EarlierIntrinsic->setLength(TrimmedLength); + + EarlierSize = NewLength; + if (!IsOverwriteEnd) { + int64_t OffsetMoved = (LaterOffset - EarlierOffset); + Value *Indices[1] = { + ConstantInt::get(EarlierWriteLength->getType(), OffsetMoved)}; + GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds( + EarlierIntrinsic->getRawDest(), Indices, "", EarlierWrite); + EarlierIntrinsic->setDest(NewDestGEP); + EarlierOffset = EarlierOffset + OffsetMoved; + } + return true; +} + +static bool tryToShortenEnd(Instruction *EarlierWrite, + OverlapIntervalsTy &IntervalMap, + int64_t &EarlierStart, int64_t &EarlierSize) { + if (IntervalMap.empty() || !isShortenableAtTheEnd(EarlierWrite)) + return false; + + OverlapIntervalsTy::iterator OII = --IntervalMap.end(); + int64_t LaterStart = OII->second; + int64_t LaterSize = OII->first - LaterStart; + + if (LaterStart > EarlierStart && LaterStart < EarlierStart + EarlierSize && + LaterStart + LaterSize >= EarlierStart + EarlierSize) { + if (tryToShorten(EarlierWrite, EarlierStart, EarlierSize, LaterStart, + LaterSize, true)) { + IntervalMap.erase(OII); + return true; + } + } + return false; +} + +static bool tryToShortenBegin(Instruction *EarlierWrite, + OverlapIntervalsTy &IntervalMap, + int64_t &EarlierStart, int64_t &EarlierSize) { + if (IntervalMap.empty() || !isShortenableAtTheBeginning(EarlierWrite)) + return false; + + OverlapIntervalsTy::iterator OII = IntervalMap.begin(); + int64_t LaterStart = OII->second; + int64_t LaterSize = OII->first - LaterStart; + + if (LaterStart <= EarlierStart && LaterStart + LaterSize > EarlierStart) { + assert(LaterStart + LaterSize < EarlierStart + EarlierSize && + "Should have been handled as OverwriteComplete"); + if (tryToShorten(EarlierWrite, EarlierStart, EarlierSize, LaterStart, + LaterSize, false)) { + IntervalMap.erase(OII); + return true; + } + } + return false; +} + +static bool removePartiallyOverlappedStores(AliasAnalysis *AA, + const DataLayout &DL, + InstOverlapIntervalsTy &IOL) { + bool Changed = false; + for (auto OI : IOL) { + Instruction *EarlierWrite = OI.first; + MemoryLocation Loc = getLocForWrite(EarlierWrite, *AA); + assert(isRemovable(EarlierWrite) && "Expect only removable instruction"); + assert(Loc.Size != MemoryLocation::UnknownSize && "Unexpected mem loc"); + + const Value *Ptr = Loc.Ptr->stripPointerCasts(); + int64_t EarlierStart = 0; + int64_t EarlierSize = int64_t(Loc.Size); + GetPointerBaseWithConstantOffset(Ptr, EarlierStart, DL); + OverlapIntervalsTy &IntervalMap = OI.second; + Changed |= + tryToShortenEnd(EarlierWrite, IntervalMap, EarlierStart, EarlierSize); + if (IntervalMap.empty()) + continue; + Changed |= + tryToShortenBegin(EarlierWrite, IntervalMap, EarlierStart, EarlierSize); + } + return Changed; +} + static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, AliasAnalysis *AA, MemoryDependenceResults *MD, const DataLayout &DL, - const TargetLibraryInfo *TLI) { + const TargetLibraryInfo *TLI, + InstOverlapIntervalsTy &IOL, + DenseMap<Instruction*, size_t> *InstrOrdering) { // Must be a store instruction. 
StoreInst *SI = dyn_cast<StoreInst>(Inst); if (!SI) @@ -837,7 +964,7 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, DEBUG(dbgs() << "DSE: Remove Store Of Load from same pointer:\n LOAD: " << *DepLoad << "\n STORE: " << *SI << '\n'); - deleteDeadInstruction(SI, &BBI, *MD, *TLI); + deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, InstrOrdering); ++NumRedundantStores; return true; } @@ -855,7 +982,7 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, dbgs() << "DSE: Remove null store to the calloc'ed object:\n DEAD: " << *Inst << "\n OBJECT: " << *UnderlyingPointer << '\n'); - deleteDeadInstruction(SI, &BBI, *MD, *TLI); + deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, InstrOrdering); ++NumRedundantStores; return true; } @@ -869,6 +996,12 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, const DataLayout &DL = BB.getModule()->getDataLayout(); bool MadeChange = false; + // FIXME: Maybe change this to use some abstraction like OrderedBasicBlock? + // The current OrderedBasicBlock can't deal with mutation at the moment. + size_t LastThrowingInstIndex = 0; + DenseMap<Instruction*, size_t> InstrOrdering; + size_t InstrIndex = 1; + // A map of interval maps representing partially-overwritten value parts. InstOverlapIntervalsTy IOL; @@ -876,7 +1009,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) { // Handle 'free' calls specially. if (CallInst *F = isFreeCall(&*BBI, TLI)) { - MadeChange |= handleFree(F, AA, MD, DT, TLI); + MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, &InstrOrdering); // Increment BBI after handleFree has potentially deleted instructions. // This ensures we maintain a valid iterator. ++BBI; @@ -885,12 +1018,19 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, Instruction *Inst = &*BBI++; + size_t CurInstNumber = InstrIndex++; + InstrOrdering.insert(std::make_pair(Inst, CurInstNumber)); + if (Inst->mayThrow()) { + LastThrowingInstIndex = CurInstNumber; + continue; + } + // Check to see if Inst writes to memory. If not, continue. if (!hasMemoryWrite(Inst, *TLI)) continue; // eliminateNoopStore will update in iterator, if necessary. - if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI)) { + if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL, &InstrOrdering)) { MadeChange = true; continue; } @@ -910,6 +1050,13 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, if (!Loc.Ptr) continue; + // Loop until we find a store we can eliminate or a load that + // invalidates the analysis. Without an upper bound on the number of + // instructions examined, this analysis can become very time-consuming. + // However, the potential gain diminishes as we process more instructions + // without eliminating any of them. Therefore, we limit the number of + // instructions we look at. + auto Limit = MD->getDefaultBlockScanLimit(); while (InstDep.isDef() || InstDep.isClobber()) { // Get the memory clobbered by the instruction we depend on. MemDep will // skip any instructions that 'Loc' clearly doesn't interact with. If we @@ -924,6 +1071,31 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, if (!DepLoc.Ptr) break; + // Make sure we don't look past a call which might throw. 
This is an + // issue because MemoryDependenceAnalysis works in the wrong direction: + // it finds instructions which dominate the current instruction, rather than + // instructions which are post-dominated by the current instruction. + // + // If the underlying object is a non-escaping memory allocation, any store + // to it is dead along the unwind edge. Otherwise, we need to preserve + // the store. + size_t DepIndex = InstrOrdering.lookup(DepWrite); + assert(DepIndex && "Unexpected instruction"); + if (DepIndex <= LastThrowingInstIndex) { + const Value* Underlying = GetUnderlyingObject(DepLoc.Ptr, DL); + bool IsStoreDeadOnUnwind = isa<AllocaInst>(Underlying); + if (!IsStoreDeadOnUnwind) { + // We're looking for a call to an allocation function + // where the allocation doesn't escape before the last + // throwing instruction; PointerMayBeCaptured + // reasonably fast approximation. + IsStoreDeadOnUnwind = isAllocLikeFn(Underlying, TLI) && + !PointerMayBeCaptured(Underlying, false, true); + } + if (!IsStoreDeadOnUnwind) + break; + } + // If we find a write that is a) removable (i.e., non-volatile), b) is // completely obliterated by the store to 'Loc', and c) which we know that // 'Inst' doesn't load from, then we can remove it. @@ -938,7 +1110,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, << *DepWrite << "\n KILLER: " << *Inst << '\n'); // Delete the store and now-dead instructions that feed it. - deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI); + deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, &InstrOrdering); ++NumFastStores; MadeChange = true; @@ -948,48 +1120,14 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, } else if ((OR == OverwriteEnd && isShortenableAtTheEnd(DepWrite)) || ((OR == OverwriteBegin && isShortenableAtTheBeginning(DepWrite)))) { - // TODO: base this on the target vector size so that if the earlier - // store was too small to get vector writes anyway then its likely - // a good idea to shorten it - // Power of 2 vector writes are probably always a bad idea to optimize - // as any store/memset/memcpy is likely using vector instructions so - // shortening it to not vector size is likely to be slower - MemIntrinsic *DepIntrinsic = cast<MemIntrinsic>(DepWrite); - unsigned DepWriteAlign = DepIntrinsic->getAlignment(); + assert(!EnablePartialOverwriteTracking && "Do not expect to perform " + "when partial-overwrite " + "tracking is enabled"); + int64_t EarlierSize = DepLoc.Size; + int64_t LaterSize = Loc.Size; bool IsOverwriteEnd = (OR == OverwriteEnd); - if (!IsOverwriteEnd) - InstWriteOffset = int64_t(InstWriteOffset + Loc.Size); - - if ((llvm::isPowerOf2_64(InstWriteOffset) && - DepWriteAlign <= InstWriteOffset) || - ((DepWriteAlign != 0) && InstWriteOffset % DepWriteAlign == 0)) { - - DEBUG(dbgs() << "DSE: Remove Dead Store:\n OW " - << (IsOverwriteEnd ? "END" : "BEGIN") << ": " - << *DepWrite << "\n KILLER (offset " - << InstWriteOffset << ", " << DepLoc.Size << ")" - << *Inst << '\n'); - - int64_t NewLength = - IsOverwriteEnd - ? 
InstWriteOffset - DepWriteOffset - : DepLoc.Size - (InstWriteOffset - DepWriteOffset); - - Value *DepWriteLength = DepIntrinsic->getLength(); - Value *TrimmedLength = - ConstantInt::get(DepWriteLength->getType(), NewLength); - DepIntrinsic->setLength(TrimmedLength); - - if (!IsOverwriteEnd) { - int64_t OffsetMoved = (InstWriteOffset - DepWriteOffset); - Value *Indices[1] = { - ConstantInt::get(DepWriteLength->getType(), OffsetMoved)}; - GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds( - DepIntrinsic->getRawDest(), Indices, "", DepWrite); - DepIntrinsic->setDest(NewDestGEP); - } - MadeChange = true; - } + MadeChange |= tryToShorten(DepWrite, DepWriteOffset, EarlierSize, + InstWriteOffset, LaterSize, IsOverwriteEnd); } } @@ -1007,15 +1145,19 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, if (AA->getModRefInfo(DepWrite, Loc) & MRI_Ref) break; - InstDep = MD->getPointerDependencyFrom(Loc, false, - DepWrite->getIterator(), &BB); + InstDep = MD->getPointerDependencyFrom(Loc, /*isLoad=*/ false, + DepWrite->getIterator(), &BB, + /*QueryInst=*/ nullptr, &Limit); } } + if (EnablePartialOverwriteTracking) + MadeChange |= removePartiallyOverlappedStores(AA, DL, IOL); + // If this block ends in a return, unwind, or unreachable, all allocas are // dead at its end, which means stores to them are also dead. if (BB.getTerminator()->getNumSuccessors() == 0) - MadeChange |= handleEndBlock(BB, AA, MD, TLI); + MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, &InstrOrdering); return MadeChange; } @@ -1029,6 +1171,7 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis *AA, // cycles that will confuse alias analysis. if (DT->isReachableFromEntry(&BB)) MadeChange |= eliminateDeadStores(BB, AA, MD, DT, TLI); + return MadeChange; } diff --git a/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp index 0b16e27..16e08ee 100644 --- a/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/MemorySSA.h" #include <deque> using namespace llvm; using namespace llvm::PatternMatch; @@ -251,6 +252,7 @@ public: const TargetTransformInfo &TTI; DominatorTree &DT; AssumptionCache &AC; + MemorySSA *MSSA; typedef RecyclingAllocator< BumpPtrAllocator, ScopedHashTableVal<SimpleValue, Value *>> AllocatorTy; typedef ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>, @@ -312,8 +314,8 @@ public: /// \brief Set up the EarlyCSE runner for a particular function. EarlyCSE(const TargetLibraryInfo &TLI, const TargetTransformInfo &TTI, - DominatorTree &DT, AssumptionCache &AC) - : TLI(TLI), TTI(TTI), DT(DT), AC(AC), CurrentGeneration(0) {} + DominatorTree &DT, AssumptionCache &AC, MemorySSA *MSSA) + : TLI(TLI), TTI(TTI), DT(DT), AC(AC), MSSA(MSSA), CurrentGeneration(0) {} bool run(); @@ -338,7 +340,7 @@ private: }; // Contains all the needed information to create a stack for doing a depth - // first tranversal of the tree. This includes scopes for values, loads, and + // first traversal of the tree. This includes scopes for values, loads, and // calls as well as the generation. There is a child iterator so that the // children do not need to be store separately. 
class StackNode { @@ -479,17 +481,93 @@ private: bool processNode(DomTreeNode *Node); Value *getOrCreateResult(Value *Inst, Type *ExpectedType) const { - if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) + if (auto *LI = dyn_cast<LoadInst>(Inst)) return LI; - else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) + if (auto *SI = dyn_cast<StoreInst>(Inst)) return SI->getValueOperand(); assert(isa<IntrinsicInst>(Inst) && "Instruction not supported"); return TTI.getOrCreateResultFromMemIntrinsic(cast<IntrinsicInst>(Inst), ExpectedType); } + + bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration, + Instruction *EarlierInst, Instruction *LaterInst); + + void removeMSSA(Instruction *Inst) { + if (!MSSA) + return; + // Removing a store here can leave MemorySSA in an unoptimized state by + // creating MemoryPhis that have identical arguments and by creating + // MemoryUses whose defining access is not an actual clobber. We handle the + // phi case eagerly here. The non-optimized MemoryUse case is lazily + // updated by MemorySSA getClobberingMemoryAccess. + if (MemoryAccess *MA = MSSA->getMemoryAccess(Inst)) { + // Optimize MemoryPhi nodes that may become redundant by having all the + // same input values once MA is removed. + SmallVector<MemoryPhi *, 4> PhisToCheck; + SmallVector<MemoryAccess *, 8> WorkQueue; + WorkQueue.push_back(MA); + // Process MemoryPhi nodes in FIFO order using a ever-growing vector since + // we shouldn't be processing that many phis and this will avoid an + // allocation in almost all cases. + for (unsigned I = 0; I < WorkQueue.size(); ++I) { + MemoryAccess *WI = WorkQueue[I]; + + for (auto *U : WI->users()) + if (MemoryPhi *MP = dyn_cast<MemoryPhi>(U)) + PhisToCheck.push_back(MP); + + MSSA->removeMemoryAccess(WI); + + for (MemoryPhi *MP : PhisToCheck) { + MemoryAccess *FirstIn = MP->getIncomingValue(0); + if (all_of(MP->incoming_values(), + [=](Use &In) { return In == FirstIn; })) + WorkQueue.push_back(MP); + } + PhisToCheck.clear(); + } + } + } }; } +/// Determine if the memory referenced by LaterInst is from the same heap +/// version as EarlierInst. +/// This is currently called in two scenarios: +/// +/// load p +/// ... +/// load p +/// +/// and +/// +/// x = load p +/// ... +/// store x, p +/// +/// in both cases we want to verify that there are no possible writes to the +/// memory referenced by p between the earlier and later instruction. +bool EarlyCSE::isSameMemGeneration(unsigned EarlierGeneration, + unsigned LaterGeneration, + Instruction *EarlierInst, + Instruction *LaterInst) { + // Check the simple memory generation tracking first. + if (EarlierGeneration == LaterGeneration) + return true; + + if (!MSSA) + return false; + + // Since we know LaterDef dominates LaterInst and EarlierInst dominates + // LaterInst, if LaterDef dominates EarlierInst then it can't occur between + // EarlierInst and LaterInst and neither can any other write that potentially + // clobbers LaterInst. + MemoryAccess *LaterDef = + MSSA->getWalker()->getClobberingMemoryAccess(LaterInst); + return MSSA->dominates(LaterDef, MSSA->getMemoryAccess(EarlierInst)); +} + bool EarlyCSE::processNode(DomTreeNode *Node) { bool Changed = false; BasicBlock *BB = Node->getBlock(); @@ -547,6 +625,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // Dead instructions should just be removed. 
if (isInstructionTriviallyDead(Inst, &TLI)) { DEBUG(dbgs() << "EarlyCSE DCE: " << *Inst << '\n'); + removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; ++NumSimplify; @@ -562,6 +641,19 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { continue; } + // Skip invariant.start intrinsics since they only read memory, and we can + // forward values across it. Also, we dont need to consume the last store + // since the semantics of invariant.start allow us to perform DSE of the + // last store, if there was a store following invariant.start. Consider: + // + // store 30, i8* p + // invariant.start(p) + // store 40, i8* p + // We can DSE the store to 30, since the store 40 to invariant location p + // causes undefined behaviour. + if (match(Inst, m_Intrinsic<Intrinsic::invariant_start>())) + continue; + if (match(Inst, m_Intrinsic<Intrinsic::experimental_guard>())) { if (auto *CondI = dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0))) { @@ -588,6 +680,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { Changed = true; } if (isInstructionTriviallyDead(Inst, &TLI)) { + removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; Killed = true; @@ -606,6 +699,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { if (auto *I = dyn_cast<Instruction>(V)) I->andIRFlags(Inst); Inst->replaceAllUsesWith(V); + removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; ++NumCSE; @@ -631,24 +725,26 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // generation or the load is known to be from an invariant location, // replace this instruction. // - // A dominating invariant load implies that the location loaded from is - // unchanging beginning at the point of the invariant load, so the load - // we're CSE'ing _away_ does not need to be invariant, only the available - // load we're CSE'ing _to_ does. + // If either the dominating load or the current load are invariant, then + // we can assume the current load loads the same value as the dominating + // load. LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand()); if (InVal.DefInst != nullptr && - (InVal.Generation == CurrentGeneration || InVal.IsInvariant) && InVal.MatchingId == MemInst.getMatchingId() && // We don't yet handle removing loads with ordering of any kind. !MemInst.isVolatile() && MemInst.isUnordered() && // We can't replace an atomic load with one which isn't also atomic. - InVal.IsAtomic >= MemInst.isAtomic()) { + InVal.IsAtomic >= MemInst.isAtomic() && + (InVal.IsInvariant || MemInst.isInvariantLoad() || + isSameMemGeneration(InVal.Generation, CurrentGeneration, + InVal.DefInst, Inst))) { Value *Op = getOrCreateResult(InVal.DefInst, Inst->getType()); if (Op != nullptr) { DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << *Inst << " to: " << *InVal.DefInst << '\n'); if (!Inst->use_empty()) Inst->replaceAllUsesWith(Op); + removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; ++NumCSELoad; @@ -679,11 +775,14 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { // If we have an available version of this call, and if it is the right // generation, replace this instruction. 
std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(Inst); - if (InVal.first != nullptr && InVal.second == CurrentGeneration) { + if (InVal.first != nullptr && + isSameMemGeneration(InVal.second, CurrentGeneration, InVal.first, + Inst)) { DEBUG(dbgs() << "EarlyCSE CSE CALL: " << *Inst << " to: " << *InVal.first << '\n'); if (!Inst->use_empty()) Inst->replaceAllUsesWith(InVal.first); + removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; ++NumCSECall; @@ -716,15 +815,22 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand()); if (InVal.DefInst && InVal.DefInst == getOrCreateResult(Inst, InVal.DefInst->getType()) && - InVal.Generation == CurrentGeneration && InVal.MatchingId == MemInst.getMatchingId() && // We don't yet handle removing stores with ordering of any kind. - !MemInst.isVolatile() && MemInst.isUnordered()) { + !MemInst.isVolatile() && MemInst.isUnordered() && + isSameMemGeneration(InVal.Generation, CurrentGeneration, + InVal.DefInst, Inst)) { + // It is okay to have a LastStore to a different pointer here if MemorySSA + // tells us that the load and store are from the same memory generation. + // In that case, LastStore should keep its present value since we're + // removing the current store. assert((!LastStore || ParseMemoryInst(LastStore, TTI).getPointerOperand() == - MemInst.getPointerOperand()) && - "can't have an intervening store!"); + MemInst.getPointerOperand() || + MSSA) && + "can't have an intervening store if not using MemorySSA!"); DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << *Inst << '\n'); + removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; ++NumDSE; @@ -756,6 +862,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { if (LastStoreMemInst.isMatchingMemLoc(MemInst)) { DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore << " due to: " << *Inst << '\n'); + removeMSSA(LastStore); LastStore->eraseFromParent(); Changed = true; ++NumDSE; @@ -847,13 +954,15 @@ bool EarlyCSE::run() { } PreservedAnalyses EarlyCSEPass::run(Function &F, - AnalysisManager<Function> &AM) { + FunctionAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto *MSSA = + UseMemorySSA ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA() : nullptr; - EarlyCSE CSE(TLI, TTI, DT, AC); + EarlyCSE CSE(TLI, TTI, DT, AC, MSSA); if (!CSE.run()) return PreservedAnalyses::all(); @@ -863,6 +972,8 @@ PreservedAnalyses EarlyCSEPass::run(Function &F, PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); PA.preserve<GlobalsAA>(); + if (UseMemorySSA) + PA.preserve<MemorySSAAnalysis>(); return PA; } @@ -874,12 +985,16 @@ namespace { /// canonicalize things as it goes. It is intended to be fast and catch obvious /// cases so that instcombine and other passes are more effective. It is /// expected that a later pass of GVN will catch the interesting/hard cases. 
-class EarlyCSELegacyPass : public FunctionPass { +template<bool UseMemorySSA> +class EarlyCSELegacyCommonPass : public FunctionPass { public: static char ID; - EarlyCSELegacyPass() : FunctionPass(ID) { - initializeEarlyCSELegacyPassPass(*PassRegistry::getPassRegistry()); + EarlyCSELegacyCommonPass() : FunctionPass(ID) { + if (UseMemorySSA) + initializeEarlyCSEMemSSALegacyPassPass(*PassRegistry::getPassRegistry()); + else + initializeEarlyCSELegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override { @@ -890,8 +1005,10 @@ public: auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto *MSSA = + UseMemorySSA ? &getAnalysis<MemorySSAWrapperPass>().getMSSA() : nullptr; - EarlyCSE CSE(TLI, TTI, DT, AC); + EarlyCSE CSE(TLI, TTI, DT, AC, MSSA); return CSE.run(); } @@ -901,15 +1018,20 @@ public: AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + if (UseMemorySSA) { + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } AU.addPreserved<GlobalsAAWrapperPass>(); AU.setPreservesCFG(); } }; } -char EarlyCSELegacyPass::ID = 0; +using EarlyCSELegacyPass = EarlyCSELegacyCommonPass</*UseMemorySSA=*/false>; -FunctionPass *llvm::createEarlyCSEPass() { return new EarlyCSELegacyPass(); } +template<> +char EarlyCSELegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(EarlyCSELegacyPass, "early-cse", "Early CSE", false, false) @@ -918,3 +1040,26 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(EarlyCSELegacyPass, "early-cse", "Early CSE", false, false) + +using EarlyCSEMemSSALegacyPass = + EarlyCSELegacyCommonPass</*UseMemorySSA=*/true>; + +template<> +char EarlyCSEMemSSALegacyPass::ID = 0; + +FunctionPass *llvm::createEarlyCSEPass(bool UseMemorySSA) { + if (UseMemorySSA) + return new EarlyCSEMemSSALegacyPass(); + else + return new EarlyCSELegacyPass(); +} + +INITIALIZE_PASS_BEGIN(EarlyCSEMemSSALegacyPass, "early-cse-memssa", + "Early CSE w/ MemorySSA", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) +INITIALIZE_PASS_END(EarlyCSEMemSSALegacyPass, "early-cse-memssa", + "Early CSE w/ MemorySSA", false, false) diff --git a/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp b/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp index 7aa6dc6..545036d 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -190,21 +190,14 @@ void Float2IntPass::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) { seen(I, badRange()); break; - case Instruction::UIToFP: { - // Path terminated cleanly. - unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits(); - APInt Min = APInt::getMinValue(BW).zextOrSelf(MaxIntegerBW+1); - APInt Max = APInt::getMaxValue(BW).zextOrSelf(MaxIntegerBW+1); - seen(I, validateRange(ConstantRange(Min, Max))); - continue; - } - + case Instruction::UIToFP: case Instruction::SIToFP: { - // Path terminated cleanly. 
+ // Path terminated cleanly - use the type of the integer input to seed + // the analysis. unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits(); - APInt SMin = APInt::getSignedMinValue(BW).sextOrSelf(MaxIntegerBW+1); - APInt SMax = APInt::getSignedMaxValue(BW).sextOrSelf(MaxIntegerBW+1); - seen(I, validateRange(ConstantRange(SMin, SMax))); + auto Input = ConstantRange(BW, true); + auto CastOp = (Instruction::CastOps)I->getOpcode(); + seen(I, validateRange(Input.castOp(CastOp, MaxIntegerBW+1))); continue; } @@ -249,23 +242,12 @@ void Float2IntPass::walkForwards() { llvm_unreachable("Should have been handled in walkForwards!"); case Instruction::FAdd: - Op = [](ArrayRef<ConstantRange> Ops) { - assert(Ops.size() == 2 && "FAdd is a binary operator!"); - return Ops[0].add(Ops[1]); - }; - break; - case Instruction::FSub: - Op = [](ArrayRef<ConstantRange> Ops) { - assert(Ops.size() == 2 && "FSub is a binary operator!"); - return Ops[0].sub(Ops[1]); - }; - break; - case Instruction::FMul: - Op = [](ArrayRef<ConstantRange> Ops) { - assert(Ops.size() == 2 && "FMul is a binary operator!"); - return Ops[0].multiply(Ops[1]); + Op = [I](ArrayRef<ConstantRange> Ops) { + assert(Ops.size() == 2 && "its a binary operator!"); + auto BinOp = (Instruction::BinaryOps) I->getOpcode(); + return Ops[0].binaryOp(BinOp, Ops[1]); }; break; @@ -275,9 +257,12 @@ void Float2IntPass::walkForwards() { // case Instruction::FPToUI: case Instruction::FPToSI: - Op = [](ArrayRef<ConstantRange> Ops) { + Op = [I](ArrayRef<ConstantRange> Ops) { assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!"); - return Ops[0]; + // Note: We're ignoring the casts output size here as that's what the + // caller expects. + auto CastOp = (Instruction::CastOps)I->getOpcode(); + return Ops[0].castOp(CastOp, MaxIntegerBW+1); }; break; diff --git a/contrib/llvm/lib/Transforms/Scalar/GVN.cpp b/contrib/llvm/lib/Transforms/Scalar/GVN.cpp index a35a106..0137378 100644 --- a/contrib/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/GVN.cpp @@ -33,6 +33,7 @@ #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/PHITransAddr.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -338,16 +339,9 @@ GVN::Expression GVN::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { //===----------------------------------------------------------------------===// GVN::ValueTable::ValueTable() : nextValueNumber(1) {} -GVN::ValueTable::ValueTable(const ValueTable &Arg) - : valueNumbering(Arg.valueNumbering), - expressionNumbering(Arg.expressionNumbering), AA(Arg.AA), MD(Arg.MD), - DT(Arg.DT), nextValueNumber(Arg.nextValueNumber) {} -GVN::ValueTable::ValueTable(ValueTable &&Arg) - : valueNumbering(std::move(Arg.valueNumbering)), - expressionNumbering(std::move(Arg.expressionNumbering)), - AA(std::move(Arg.AA)), MD(std::move(Arg.MD)), DT(std::move(Arg.DT)), - nextValueNumber(std::move(Arg.nextValueNumber)) {} -GVN::ValueTable::~ValueTable() {} +GVN::ValueTable::ValueTable(const ValueTable &) = default; +GVN::ValueTable::ValueTable(ValueTable &&) = default; +GVN::ValueTable::~ValueTable() = default; /// add - Insert a value into the table with a specified value number. 
void GVN::ValueTable::add(Value *V, uint32_t num) { @@ -583,7 +577,7 @@ void GVN::ValueTable::verifyRemoved(const Value *V) const { // GVN Pass //===----------------------------------------------------------------------===// -PreservedAnalyses GVN::run(Function &F, AnalysisManager<Function> &AM) { +PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) { // FIXME: The order of evaluation of these 'getResult' calls is very // significant! Re-ordering these variables will cause GVN when run alone to // be less effective! We should fix memdep and basic-aa to not exhibit this @@ -593,7 +587,9 @@ PreservedAnalyses GVN::run(Function &F, AnalysisManager<Function> &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); auto &MemDep = AM.getResult<MemoryDependenceAnalysis>(F); - bool Changed = runImpl(F, AC, DT, TLI, AA, &MemDep); + auto *LI = AM.getCachedResult<LoopAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + bool Changed = runImpl(F, AC, DT, TLI, AA, &MemDep, LI, &ORE); if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; @@ -725,8 +721,9 @@ static Value *CoerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, assert(CanCoerceMustAliasedValueToLoad(StoredVal, LoadedTy, DL) && "precondition violation - materialization can't fail"); - if (auto *CExpr = dyn_cast<ConstantExpr>(StoredVal)) - StoredVal = ConstantFoldConstantExpression(CExpr, DL); + if (auto *C = dyn_cast<Constant>(StoredVal)) + if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) + StoredVal = FoldedStoredVal; // If this is already the right type, just return it. Type *StoredValTy = StoredVal->getType(); @@ -759,8 +756,9 @@ static Value *CoerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, StoredVal = IRB.CreateIntToPtr(StoredVal, LoadedTy); } - if (auto *CExpr = dyn_cast<ConstantExpr>(StoredVal)) - StoredVal = ConstantFoldConstantExpression(CExpr, DL); + if (auto *C = dyn_cast<ConstantExpr>(StoredVal)) + if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) + StoredVal = FoldedStoredVal; return StoredVal; } @@ -804,8 +802,9 @@ static Value *CoerceAvailableValueToLoadType(Value *StoredVal, Type *LoadedTy, StoredVal = IRB.CreateBitCast(StoredVal, LoadedTy, "bitcast"); } - if (auto *CExpr = dyn_cast<ConstantExpr>(StoredVal)) - StoredVal = ConstantFoldConstantExpression(CExpr, DL); + if (auto *C = dyn_cast<Constant>(StoredVal)) + if (auto *FoldedStoredVal = ConstantFoldConstant(C, DL)) + StoredVal = FoldedStoredVal; return StoredVal; } @@ -838,16 +837,6 @@ static int AnalyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr, // a must alias. AA must have gotten confused. // FIXME: Study to see if/when this happens. One case is forwarding a memset // to a load from the base of the memset. -#if 0 - if (LoadOffset == StoreOffset) { - dbgs() << "STORE/LOAD DEP WITH COMMON POINTER MISSED:\n" - << "Base = " << *StoreBase << "\n" - << "Store Ptr = " << *WritePtr << "\n" - << "Store Offs = " << StoreOffset << "\n" - << "Load Ptr = " << *LoadPtr << "\n"; - abort(); - } -#endif // If the load and store don't overlap at all, the store doesn't provide // anything to the load. In this case, they really don't alias at all, AA @@ -856,8 +845,8 @@ static int AnalyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr, if ((WriteSizeInBits & 7) | (LoadSize & 7)) return -1; - uint64_t StoreSize = WriteSizeInBits >> 3; // Convert to bytes. - LoadSize >>= 3; + uint64_t StoreSize = WriteSizeInBits / 8; // Convert to bytes. 
+ LoadSize /= 8; bool isAAFailure = false; @@ -866,17 +855,8 @@ static int AnalyzeLoadFromClobberingWrite(Type *LoadTy, Value *LoadPtr, else isAAFailure = LoadOffset+int64_t(LoadSize) <= StoreOffset; - if (isAAFailure) { -#if 0 - dbgs() << "STORE LOAD DEP WITH COMMON BASE:\n" - << "Base = " << *StoreBase << "\n" - << "Store Ptr = " << *WritePtr << "\n" - << "Store Offs = " << StoreOffset << "\n" - << "Load Ptr = " << *LoadPtr << "\n"; - abort(); -#endif + if (isAAFailure) return -1; - } // If the Load isn't completely contained within the stored bits, we don't // have all the bits to feed it. We could do something crazy in the future @@ -1229,6 +1209,38 @@ static bool isLifetimeStart(const Instruction *Inst) { return false; } +/// \brief Try to locate the three instruction involved in a missed +/// load-elimination case that is due to an intervening store. +static void reportMayClobberedLoad(LoadInst *LI, MemDepResult DepInfo, + DominatorTree *DT, + OptimizationRemarkEmitter *ORE) { + using namespace ore; + User *OtherAccess = nullptr; + + OptimizationRemarkMissed R(DEBUG_TYPE, "LoadClobbered", LI); + R << "load of type " << NV("Type", LI->getType()) << " not eliminated" + << setExtraArgs(); + + for (auto *U : LI->getPointerOperand()->users()) + if (U != LI && (isa<LoadInst>(U) || isa<StoreInst>(U)) && + DT->dominates(cast<Instruction>(U), LI)) { + // FIXME: for now give up if there are multiple memory accesses that + // dominate the load. We need further analysis to decide which one is + // that we're forwarding from. + if (OtherAccess) + OtherAccess = nullptr; + else + OtherAccess = U; + } + + if (OtherAccess) + R << " in favor of " << NV("OtherAccess", OtherAccess); + + R << " because it is clobbered by " << NV("ClobberedBy", DepInfo.getInst()); + + ORE->emit(R); +} + bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, Value *Address, AvailableValue &Res) { @@ -1293,6 +1305,10 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, Instruction *I = DepInfo.getInst(); dbgs() << " is clobbered by " << *I << '\n'; ); + + if (ORE->allowExtraAnalysis()) + reportMayClobberedLoad(LI, DepInfo, DT, ORE); + return false; } assert(DepInfo.isDef() && "follows from above"); @@ -1556,6 +1572,13 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // Assign value numbers to the new instructions. for (Instruction *I : NewInsts) { + // Instructions that have been inserted in predecessor(s) to materialize + // the load address do not retain their original debug locations. Doing + // so could lead to confusing (but correct) source attributions. + // FIXME: How do we retain source locations without causing poor debugging + // behavior? + I->setDebugLoc(DebugLoc()); + // FIXME: We really _ought_ to insert these value numbers into their // parent's availability map. However, in doing so, we risk getting into // ordering issues. If a block hasn't been processed yet, we would be @@ -1585,8 +1608,11 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, if (auto *RangeMD = LI->getMetadata(LLVMContext::MD_range)) NewLoad->setMetadata(LLVMContext::MD_range, RangeMD); - // Transfer DebugLoc. - NewLoad->setDebugLoc(LI->getDebugLoc()); + // We do not propagate the old load's debug location, because the new + // load now lives in a different BB, and we want to avoid a jumpy line + // table. + // FIXME: How do we retain source locations without causing poor debugging + // behavior? // Add the newly created load. 
ValuesPerBlock.push_back(AvailableValueInBlock::get(UnavailablePred, @@ -1605,10 +1631,21 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, if (V->getType()->getScalarType()->isPointerTy()) MD->invalidateCachedPointerInfo(V); markInstructionForDeletion(LI); + ORE->emit(OptimizationRemark(DEBUG_TYPE, "LoadPRE", LI) + << "load eliminated by PRE"); ++NumPRELoad; return true; } +static void reportLoadElim(LoadInst *LI, Value *AvailableValue, + OptimizationRemarkEmitter *ORE) { + using namespace ore; + ORE->emit(OptimizationRemark(DEBUG_TYPE, "LoadElim", LI) + << "load of type " << NV("Type", LI->getType()) << " eliminated" + << setExtraArgs() << " in favor of " + << NV("InfavorOfValue", AvailableValue)); +} + /// Attempt to eliminate a load whose dependencies are /// non-local by performing PHI construction. bool GVN::processNonLocalLoad(LoadInst *LI) { @@ -1673,12 +1710,16 @@ bool GVN::processNonLocalLoad(LoadInst *LI) { if (isa<PHINode>(V)) V->takeName(LI); if (Instruction *I = dyn_cast<Instruction>(V)) - if (LI->getDebugLoc()) + // If instruction I has debug info, then we should not update it. + // Also, if I has a null DebugLoc, then it is still potentially incorrect + // to propagate LI's DebugLoc because LI may not post-dominate I. + if (LI->getDebugLoc() && ValuesPerBlock.size() != 1) I->setDebugLoc(LI->getDebugLoc()); if (V->getType()->getScalarType()->isPointerTy()) MD->invalidateCachedPointerInfo(V); markInstructionForDeletion(LI); ++NumGVNLoad; + reportLoadElim(LI, V, ORE); return true; } @@ -1754,7 +1795,12 @@ static void patchReplacementInstruction(Instruction *I, Value *Repl) { // Patch the replacement so that it is not more restrictive than the value // being replaced. - ReplInst->andIRFlags(I); + // Note that if 'I' is a load being replaced by some operation, + // for example, by an arithmetic operation, then andIRFlags() + // would just erase all math flags from the original arithmetic + // operation, which is clearly not wanted and not needed. + if (!isa<LoadInst>(I)) + ReplInst->andIRFlags(I); // FIXME: If both the original and replacement value are part of the // same control-flow region (meaning that the execution of one @@ -1820,6 +1866,7 @@ bool GVN::processLoad(LoadInst *L) { patchAndReplaceAllUsesWith(L, AvailableValue); markInstructionForDeletion(L); ++NumGVNLoad; + reportLoadElim(L, AvailableValue, ORE); // Tell MDA to rexamine the reused pointer since we might have more // information after forwarding it. if (MD && AvailableValue->getType()->getScalarType()->isPointerTy()) @@ -2197,7 +2244,8 @@ bool GVN::processInstruction(Instruction *I) { /// runOnFunction - This is the main transformation entry point for a function. 
bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, const TargetLibraryInfo &RunTLI, AAResults &RunAA, - MemoryDependenceResults *RunMD) { + MemoryDependenceResults *RunMD, LoopInfo *LI, + OptimizationRemarkEmitter *RunORE) { AC = &RunAC; DT = &RunDT; VN.setDomTree(DT); @@ -2205,6 +2253,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, VN.setAliasAnalysis(&RunAA); MD = RunMD; VN.setMemDep(MD); + ORE = RunORE; bool Changed = false; bool ShouldContinue = true; @@ -2214,9 +2263,9 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ) { BasicBlock *BB = &*FI++; - bool removedBlock = - MergeBlockIntoPredecessor(BB, DT, /* LoopInfo */ nullptr, MD); - if (removedBlock) ++NumGVNBlocks; + bool removedBlock = MergeBlockIntoPredecessor(BB, DT, LI, MD); + if (removedBlock) + ++NumGVNBlocks; Changed |= removedBlock; } @@ -2711,13 +2760,17 @@ public: if (skipFunction(F)) return false; + auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); + return Impl.runImpl( F, getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), getAnalysis<DominatorTreeWrapperPass>().getDomTree(), getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), getAnalysis<AAResultsWrapperPass>().getAAResults(), NoLoads ? nullptr - : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep()); + : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(), + LIWP ? &LIWP->getLoopInfo() : nullptr, + &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE()); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -2730,6 +2783,7 @@ public: AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); } private: @@ -2751,4 +2805,5 @@ INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(GVNLegacyPass, "gvn", "Global Value Numbering", false, false) diff --git a/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp index cce1db3..f8e1d2e 100644 --- a/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp @@ -9,20 +9,23 @@ // // This pass hoists expressions from branches to a common dominator. It uses // GVN (global value numbering) to discover expressions computing the same -// values. The primary goal is to reduce the code size, and in some -// cases reduce critical path (by exposing more ILP). +// values. The primary goals of code-hoisting are: +// 1. To reduce the code size. +// 2. In some cases reduce critical path (by exposing more ILP). +// // Hoisting may affect the performance in some cases. To mitigate that, hoisting // is disabled in the following cases. // 1. Scalars across calls. // 2. geps when corresponding load/store cannot be hoisted. 
//===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/GVN.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/MemorySSA.h" using namespace llvm; @@ -47,15 +50,25 @@ static cl::opt<int> MaxNumberOfBBSInPath( cl::desc("Max number of basic blocks on the path between " "hoisting locations (default = 4, unlimited = -1)")); +static cl::opt<int> MaxDepthInBB( + "gvn-hoist-max-depth", cl::Hidden, cl::init(100), + cl::desc("Hoist instructions from the beginning of the BB up to the " + "maximum specified depth (default = 100, unlimited = -1)")); + +static cl::opt<int> + MaxChainLength("gvn-hoist-max-chain-length", cl::Hidden, cl::init(10), + cl::desc("Maximum length of dependent chains to hoist " + "(default = 10, unlimited = -1)")); + namespace { // Provides a sorting function based on the execution order of two instructions. struct SortByDFSIn { private: - DenseMap<const BasicBlock *, unsigned> &DFSNumber; + DenseMap<const Value *, unsigned> &DFSNumber; public: - SortByDFSIn(DenseMap<const BasicBlock *, unsigned> &D) : DFSNumber(D) {} + SortByDFSIn(DenseMap<const Value *, unsigned> &D) : DFSNumber(D) {} // Returns true when A executes before B. bool operator()(const Instruction *A, const Instruction *B) const { @@ -68,16 +81,16 @@ public: const BasicBlock *BA = A->getParent(); const BasicBlock *BB = B->getParent(); - unsigned NA = DFSNumber[BA]; - unsigned NB = DFSNumber[BB]; - if (NA < NB) - return true; - if (NA == NB) { - // Sort them in the order they occur in the same basic block. - BasicBlock::const_iterator AI(A), BI(B); - return std::distance(AI, BI) < 0; + unsigned ADFS, BDFS; + if (BA == BB) { + ADFS = DFSNumber.lookup(A); + BDFS = DFSNumber.lookup(B); + } else { + ADFS = DFSNumber.lookup(BA); + BDFS = DFSNumber.lookup(BB); } - return false; + assert(ADFS && BDFS); + return ADFS < BDFS; } }; @@ -172,27 +185,77 @@ typedef DenseMap<const BasicBlock *, bool> BBSideEffectsSet; typedef SmallVector<Instruction *, 4> SmallVecInsn; typedef SmallVectorImpl<Instruction *> SmallVecImplInsn; +static void combineKnownMetadata(Instruction *ReplInst, Instruction *I) { + static const unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_range, + LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, + LLVMContext::MD_invariant_group}; + combineMetadata(ReplInst, I, KnownIDs); +} + // This pass hoists common computations across branches sharing common // dominator. The primary goal is to reduce the code size, and in some // cases reduce critical path (by exposing more ILP). class GVNHoist { public: + GVNHoist(DominatorTree *DT, AliasAnalysis *AA, MemoryDependenceResults *MD, + MemorySSA *MSSA) + : DT(DT), AA(AA), MD(MD), MSSA(MSSA), + HoistingGeps(false), + HoistedCtr(0) + { } + + bool run(Function &F) { + VN.setDomTree(DT); + VN.setAliasAnalysis(AA); + VN.setMemDep(MD); + bool Res = false; + // Perform DFS Numbering of instructions. + unsigned BBI = 0; + for (const BasicBlock *BB : depth_first(&F.getEntryBlock())) { + DFSNumber[BB] = ++BBI; + unsigned I = 0; + for (auto &Inst : *BB) + DFSNumber[&Inst] = ++I; + } + + int ChainLength = 0; + + // FIXME: use lazy evaluation of VN to avoid the fix-point computation. 
+ while (1) { + if (MaxChainLength != -1 && ++ChainLength >= MaxChainLength) + return Res; + + auto HoistStat = hoistExpressions(F); + if (HoistStat.first + HoistStat.second == 0) + return Res; + + if (HoistStat.second > 0) + // To address a limitation of the current GVN, we need to rerun the + // hoisting after we hoisted loads or stores in order to be able to + // hoist all scalars dependent on the hoisted ld/st. + VN.clear(); + + Res = true; + } + + return Res; + } + +private: GVN::ValueTable VN; DominatorTree *DT; AliasAnalysis *AA; MemoryDependenceResults *MD; - const bool OptForMinSize; - DenseMap<const BasicBlock *, unsigned> DFSNumber; - BBSideEffectsSet BBSideEffects; MemorySSA *MSSA; + const bool HoistingGeps; + DenseMap<const Value *, unsigned> DFSNumber; + BBSideEffectsSet BBSideEffects; int HoistedCtr; enum InsKind { Unknown, Scalar, Load, Store }; - GVNHoist(DominatorTree *Dt, AliasAnalysis *Aa, MemoryDependenceResults *Md, - bool OptForMinSize) - : DT(Dt), AA(Aa), MD(Md), OptForMinSize(OptForMinSize), HoistedCtr(0) {} - // Return true when there are exception handling in BB. bool hasEH(const BasicBlock *BB) { auto It = BBSideEffects.find(BB); @@ -213,24 +276,32 @@ public: return false; } - // Return true when all paths from A to the end of the function pass through - // either B or C. - bool hoistingFromAllPaths(const BasicBlock *A, const BasicBlock *B, - const BasicBlock *C) { - // We fully copy the WL in order to be able to remove items from it. - SmallPtrSet<const BasicBlock *, 2> WL; - WL.insert(B); - WL.insert(C); - - for (auto It = df_begin(A), E = df_end(A); It != E;) { - // There exists a path from A to the exit of the function if we are still - // iterating in DF traversal and we removed all instructions from the work - // list. - if (WL.empty()) + // Return true when a successor of BB dominates A. + bool successorDominate(const BasicBlock *BB, const BasicBlock *A) { + for (const BasicBlock *Succ : BB->getTerminator()->successors()) + if (DT->dominates(Succ, A)) + return true; + + return false; + } + + // Return true when all paths from HoistBB to the end of the function pass + // through one of the blocks in WL. + bool hoistingFromAllPaths(const BasicBlock *HoistBB, + SmallPtrSetImpl<const BasicBlock *> &WL) { + + // Copy WL as the loop will remove elements from it. + SmallPtrSet<const BasicBlock *, 2> WorkList(WL.begin(), WL.end()); + + for (auto It = df_begin(HoistBB), E = df_end(HoistBB); It != E;) { + // There exists a path from HoistBB to the exit of the function if we are + // still iterating in DF traversal and we removed all instructions from + // the work list. + if (WorkList.empty()) return false; const BasicBlock *BB = *It; - if (WL.erase(BB)) { + if (WorkList.erase(BB)) { // Stop DFS traversal when BB is in the work list. It.skipChildren(); continue; @@ -240,6 +311,11 @@ public: if (!isGuaranteedToTransferExecutionToSuccessor(BB->getTerminator())) return false; + // When reaching the back-edge of a loop, there may be a path through the + // loop that does not pass through B or C before exiting the loop. + if (successorDominate(BB, HoistBB)) + return false; + // Increment DFS traversal when not skipping children. ++It; } @@ -248,40 +324,43 @@ public: } /* Return true when I1 appears before I2 in the instructions of BB. 
*/ - bool firstInBB(BasicBlock *BB, const Instruction *I1, const Instruction *I2) { - for (Instruction &I : *BB) { - if (&I == I1) - return true; - if (&I == I2) - return false; - } - - llvm_unreachable("I1 and I2 not found in BB"); + bool firstInBB(const Instruction *I1, const Instruction *I2) { + assert(I1->getParent() == I2->getParent()); + unsigned I1DFS = DFSNumber.lookup(I1); + unsigned I2DFS = DFSNumber.lookup(I2); + assert(I1DFS && I2DFS); + return I1DFS < I2DFS; } - // Return true when there are users of Def in BB. - bool hasMemoryUseOnPath(MemoryAccess *Def, const BasicBlock *BB, - const Instruction *OldPt) { - const BasicBlock *DefBB = Def->getBlock(); - const BasicBlock *OldBB = OldPt->getParent(); - for (User *U : Def->users()) - if (auto *MU = dyn_cast<MemoryUse>(U)) { - BasicBlock *UBB = MU->getBlock(); - // Only analyze uses in BB. - if (BB != UBB) - continue; + // Return true when there are memory uses of Def in BB. + bool hasMemoryUse(const Instruction *NewPt, MemoryDef *Def, + const BasicBlock *BB) { + const MemorySSA::AccessList *Acc = MSSA->getBlockAccesses(BB); + if (!Acc) + return false; - // A use in the same block as the Def is on the path. - if (UBB == DefBB) { - assert(MSSA->locallyDominates(Def, MU) && "def not dominating use"); - return true; - } + Instruction *OldPt = Def->getMemoryInst(); + const BasicBlock *OldBB = OldPt->getParent(); + const BasicBlock *NewBB = NewPt->getParent(); + bool ReachedNewPt = false; - if (UBB != OldBB) - return true; + for (const MemoryAccess &MA : *Acc) + if (const MemoryUse *MU = dyn_cast<MemoryUse>(&MA)) { + Instruction *Insn = MU->getMemoryInst(); + + // Do not check whether MU aliases Def when MU occurs after OldPt. + if (BB == OldBB && firstInBB(OldPt, Insn)) + break; - // It is only harmful to hoist when the use is before OldPt. - if (firstInBB(UBB, MU->getMemoryInst(), OldPt)) + // Do not check whether MU aliases Def when MU occurs before NewPt. + if (BB == NewBB) { + if (!ReachedNewPt) { + if (firstInBB(Insn, NewPt)) + continue; + ReachedNewPt = true; + } + } + if (defClobbersUseOrDef(Def, MU, *AA)) return true; } @@ -289,17 +368,18 @@ public: } // Return true when there are exception handling or loads of memory Def - // between OldPt and NewPt. + // between Def and NewPt. This function is only called for stores: Def is + // the MemoryDef of the store to be hoisted. // Decrement by 1 NBBsOnAllPaths for each block between HoistPt and BB, and // return true when the counter NBBsOnAllPaths reaces 0, except when it is // initialized to -1 which is unlimited. - bool hasEHOrLoadsOnPath(const Instruction *NewPt, const Instruction *OldPt, - MemoryAccess *Def, int &NBBsOnAllPaths) { + bool hasEHOrLoadsOnPath(const Instruction *NewPt, MemoryDef *Def, + int &NBBsOnAllPaths) { const BasicBlock *NewBB = NewPt->getParent(); - const BasicBlock *OldBB = OldPt->getParent(); + const BasicBlock *OldBB = Def->getBlock(); assert(DT->dominates(NewBB, OldBB) && "invalid path"); - assert(DT->dominates(Def->getBlock(), NewBB) && + assert(DT->dominates(Def->getDefiningAccess()->getBlock(), NewBB) && "def does not dominate new hoisting point"); // Walk all basic blocks reachable in depth-first iteration on the inverse @@ -313,16 +393,16 @@ public: continue; } + // Stop walk once the limit is reached. + if (NBBsOnAllPaths == 0) + return true; + // Impossible to hoist with exceptions on the path. if (hasEH(*I)) return true; // Check that we do not move a store past loads. 
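The hasMemoryUse walk above answers the question the last comment states: would hoisting the store move it above a load that may read the same memory? The standalone sketch below shows just that hazard test, with a trivial address-equality check standing in for the alias/MemorySSA query (defClobbersUseOrDef); the types and names are invented for the sketch.

#include <cassert>
#include <string>
#include <vector>

struct MemOp {
  bool isLoad;
  std::string addr;   // stand-in for what alias analysis would compare
};

// Return true when hoisting a store of `addr` above the ops in `between`
// would move it past a load that may read the same location.
bool storePastLoads(const std::string &addr, const std::vector<MemOp> &between) {
  for (const MemOp &op : between)
    if (op.isLoad && op.addr == addr)   // the real pass asks AA, not string equality
      return true;
  return false;
}

int main() {
  std::vector<MemOp> path = {{true, "p"}, {false, "q"}};
  assert(storePastLoads("p", path));    // a load of p sits on the path: unsafe
  assert(!storePastLoads("q", path));   // only a store of q: no load hazard
  return 0;
}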
- if (hasMemoryUseOnPath(Def, *I, OldPt)) - return true; - - // Stop walk once the limit is reached. - if (NBBsOnAllPaths == 0) + if (hasMemoryUse(NewPt, Def, *I)) return true; // -1 is unlimited number of blocks on all paths. @@ -355,14 +435,14 @@ public: continue; } - // Impossible to hoist with exceptions on the path. - if (hasEH(*I)) - return true; - // Stop walk once the limit is reached. if (NBBsOnAllPaths == 0) return true; + // Impossible to hoist with exceptions on the path. + if (hasEH(*I)) + return true; + // -1 is unlimited number of blocks on all paths. if (NBBsOnAllPaths != -1) --NBBsOnAllPaths; @@ -395,13 +475,13 @@ public: if (NewBB == DBB && !MSSA->isLiveOnEntryDef(D)) if (auto *UD = dyn_cast<MemoryUseOrDef>(D)) - if (firstInBB(DBB, NewPt, UD->getMemoryInst())) + if (firstInBB(NewPt, UD->getMemoryInst())) // Cannot move the load or store to NewPt above its definition in D. return false; // Check for unsafe hoistings due to side effects. if (K == InsKind::Store) { - if (hasEHOrLoadsOnPath(NewPt, OldPt, D, NBBsOnAllPaths)) + if (hasEHOrLoadsOnPath(NewPt, dyn_cast<MemoryDef>(U), NBBsOnAllPaths)) return false; } else if (hasEHOnPath(NewBB, OldBB, NBBsOnAllPaths)) return false; @@ -417,23 +497,19 @@ public: return true; } - // Return true when it is safe to hoist scalar instructions from BB1 and BB2 - // to HoistBB. - bool safeToHoistScalar(const BasicBlock *HoistBB, const BasicBlock *BB1, - const BasicBlock *BB2, int &NBBsOnAllPaths) { - // Check that the hoisted expression is needed on all paths. When HoistBB - // already contains an instruction to be hoisted, the expression is needed - // on all paths. Enable scalar hoisting at -Oz as it is safe to hoist - // scalars to a place where they are partially needed. - if (!OptForMinSize && BB1 != HoistBB && - !hoistingFromAllPaths(HoistBB, BB1, BB2)) + // Return true when it is safe to hoist scalar instructions from all blocks in + // WL to HoistBB. + bool safeToHoistScalar(const BasicBlock *HoistBB, + SmallPtrSetImpl<const BasicBlock *> &WL, + int &NBBsOnAllPaths) { + // Check that the hoisted expression is needed on all paths. + if (!hoistingFromAllPaths(HoistBB, WL)) return false; - if (hasEHOnPath(HoistBB, BB1, NBBsOnAllPaths) || - hasEHOnPath(HoistBB, BB2, NBBsOnAllPaths)) - return false; + for (const BasicBlock *BB : WL) + if (hasEHOnPath(HoistBB, BB, NBBsOnAllPaths)) + return false; - // Safe to hoist scalars from BB1 and BB2 to HoistBB. return true; } @@ -454,7 +530,7 @@ public: std::sort(InstructionsToHoist.begin(), InstructionsToHoist.end(), Pred); } - int NBBsOnAllPaths = MaxNumberOfBBSInPath; + int NumBBsOnAllPaths = MaxNumberOfBBSInPath; SmallVecImplInsn::iterator II = InstructionsToHoist.begin(); SmallVecImplInsn::iterator Start = II; @@ -462,7 +538,7 @@ public: BasicBlock *HoistBB = HoistPt->getParent(); MemoryUseOrDef *UD; if (K != InsKind::Scalar) - UD = cast<MemoryUseOrDef>(MSSA->getMemoryAccess(HoistPt)); + UD = MSSA->getMemoryAccess(HoistPt); for (++II; II != InstructionsToHoist.end(); ++II) { Instruction *Insn = *II; @@ -470,10 +546,12 @@ public: BasicBlock *NewHoistBB; Instruction *NewHoistPt; - if (BB == HoistBB) { + if (BB == HoistBB) { // Both are in the same Basic Block. NewHoistBB = HoistBB; - NewHoistPt = firstInBB(BB, Insn, HoistPt) ? Insn : HoistPt; + NewHoistPt = firstInBB(Insn, HoistPt) ? Insn : HoistPt; } else { + // If the hoisting point contains one of the instructions, + // then hoist there, otherwise hoist before the terminator. 
NewHoistBB = DT->findNearestCommonDominator(HoistBB, BB); if (NewHoistBB == BB) NewHoistPt = Insn; @@ -483,8 +561,12 @@ public: NewHoistPt = NewHoistBB->getTerminator(); } + SmallPtrSet<const BasicBlock *, 2> WL; + WL.insert(HoistBB); + WL.insert(BB); + if (K == InsKind::Scalar) { - if (safeToHoistScalar(NewHoistBB, HoistBB, BB, NBBsOnAllPaths)) { + if (safeToHoistScalar(NewHoistBB, WL, NumBBsOnAllPaths)) { // Extend HoistPt to NewHoistPt. HoistPt = NewHoistPt; HoistBB = NewHoistBB; @@ -498,13 +580,12 @@ public: // loading from the same address: for instance there may be a branch on // which the address of the load may not be initialized. if ((HoistBB == NewHoistBB || BB == NewHoistBB || - hoistingFromAllPaths(NewHoistBB, HoistBB, BB)) && + hoistingFromAllPaths(NewHoistBB, WL)) && // Also check that it is safe to move the load or store from HoistPt // to NewHoistPt, and from Insn to NewHoistPt. - safeToHoistLdSt(NewHoistPt, HoistPt, UD, K, NBBsOnAllPaths) && - safeToHoistLdSt(NewHoistPt, Insn, - cast<MemoryUseOrDef>(MSSA->getMemoryAccess(Insn)), - K, NBBsOnAllPaths)) { + safeToHoistLdSt(NewHoistPt, HoistPt, UD, K, NumBBsOnAllPaths) && + safeToHoistLdSt(NewHoistPt, Insn, MSSA->getMemoryAccess(Insn), + K, NumBBsOnAllPaths)) { // Extend HoistPt to NewHoistPt. HoistPt = NewHoistPt; HoistBB = NewHoistBB; @@ -520,10 +601,10 @@ public: // Start over from BB. Start = II; if (K != InsKind::Scalar) - UD = cast<MemoryUseOrDef>(MSSA->getMemoryAccess(*Start)); + UD = MSSA->getMemoryAccess(*Start); HoistPt = Insn; HoistBB = BB; - NBBsOnAllPaths = MaxNumberOfBBSInPath; + NumBBsOnAllPaths = MaxNumberOfBBSInPath; } // Save the last partition. @@ -567,40 +648,88 @@ public: return true; } - Instruction *firstOfTwo(Instruction *I, Instruction *J) const { - for (Instruction &I1 : *I->getParent()) - if (&I1 == I || &I1 == J) - return &I1; - llvm_unreachable("Both I and J must be from same BB"); + // Same as allOperandsAvailable with recursive check for GEP operands. + bool allGepOperandsAvailable(const Instruction *I, + const BasicBlock *HoistPt) const { + for (const Use &Op : I->operands()) + if (const auto *Inst = dyn_cast<Instruction>(&Op)) + if (!DT->dominates(Inst->getParent(), HoistPt)) { + if (const GetElementPtrInst *GepOp = + dyn_cast<GetElementPtrInst>(Inst)) { + if (!allGepOperandsAvailable(GepOp, HoistPt)) + return false; + // Gep is available if all operands of GepOp are available. + } else { + // Gep is not available if it has operands other than GEPs that are + // defined in blocks not dominating HoistPt. + return false; + } + } + return true; } - // Replace the use of From with To in Insn. - void replaceUseWith(Instruction *Insn, Value *From, Value *To) const { - for (Value::use_iterator UI = From->use_begin(), UE = From->use_end(); - UI != UE;) { - Use &U = *UI++; - if (U.getUser() == Insn) { - U.set(To); - return; + // Make all operands of the GEP available. + void makeGepsAvailable(Instruction *Repl, BasicBlock *HoistPt, + const SmallVecInsn &InstructionsToHoist, + Instruction *Gep) const { + assert(allGepOperandsAvailable(Gep, HoistPt) && + "GEP operands not available"); + + Instruction *ClonedGep = Gep->clone(); + for (unsigned i = 0, e = Gep->getNumOperands(); i != e; ++i) + if (Instruction *Op = dyn_cast<Instruction>(Gep->getOperand(i))) { + + // Check whether the operand is already available. + if (DT->dominates(Op->getParent(), HoistPt)) + continue; + + // As a GEP can refer to other GEPs, recursively make all the operands + // of this GEP available at HoistPt. 
+ if (GetElementPtrInst *GepOp = dyn_cast<GetElementPtrInst>(Op)) + makeGepsAvailable(ClonedGep, HoistPt, InstructionsToHoist, GepOp); } + + // Copy Gep and replace its uses in Repl with ClonedGep. + ClonedGep->insertBefore(HoistPt->getTerminator()); + + // Conservatively discard any optimization hints, they may differ on the + // other paths. + ClonedGep->dropUnknownNonDebugMetadata(); + + // If we have optimization hints which agree with each other along different + // paths, preserve them. + for (const Instruction *OtherInst : InstructionsToHoist) { + const GetElementPtrInst *OtherGep; + if (auto *OtherLd = dyn_cast<LoadInst>(OtherInst)) + OtherGep = cast<GetElementPtrInst>(OtherLd->getPointerOperand()); + else + OtherGep = cast<GetElementPtrInst>( + cast<StoreInst>(OtherInst)->getPointerOperand()); + ClonedGep->andIRFlags(OtherGep); } - llvm_unreachable("should replace exactly once"); + + // Replace uses of Gep with ClonedGep in Repl. + Repl->replaceUsesOfWith(Gep, ClonedGep); } - bool makeOperandsAvailable(Instruction *Repl, BasicBlock *HoistPt) const { + // In the case Repl is a load or a store, we make all their GEPs + // available: GEPs are not hoisted by default to avoid the address + // computations to be hoisted without the associated load or store. + bool makeGepOperandsAvailable(Instruction *Repl, BasicBlock *HoistPt, + const SmallVecInsn &InstructionsToHoist) const { // Check whether the GEP of a ld/st can be synthesized at HoistPt. GetElementPtrInst *Gep = nullptr; Instruction *Val = nullptr; - if (auto *Ld = dyn_cast<LoadInst>(Repl)) + if (auto *Ld = dyn_cast<LoadInst>(Repl)) { Gep = dyn_cast<GetElementPtrInst>(Ld->getPointerOperand()); - if (auto *St = dyn_cast<StoreInst>(Repl)) { + } else if (auto *St = dyn_cast<StoreInst>(Repl)) { Gep = dyn_cast<GetElementPtrInst>(St->getPointerOperand()); Val = dyn_cast<Instruction>(St->getValueOperand()); // Check that the stored value is available. if (Val) { if (isa<GetElementPtrInst>(Val)) { // Check whether we can compute the GEP at HoistPt. - if (!allOperandsAvailable(Val, HoistPt)) + if (!allGepOperandsAvailable(Val, HoistPt)) return false; } else if (!DT->dominates(Val->getParent(), HoistPt)) return false; @@ -608,20 +737,13 @@ public: } // Check whether we can compute the Gep at HoistPt. - if (!Gep || !allOperandsAvailable(Gep, HoistPt)) + if (!Gep || !allGepOperandsAvailable(Gep, HoistPt)) return false; - // Copy the gep before moving the ld/st. - Instruction *ClonedGep = Gep->clone(); - ClonedGep->insertBefore(HoistPt->getTerminator()); - replaceUseWith(Repl, Gep, ClonedGep); + makeGepsAvailable(Repl, HoistPt, InstructionsToHoist, Gep); - // Also copy Val when it is a GEP. - if (Val && isa<GetElementPtrInst>(Val)) { - Instruction *ClonedVal = Val->clone(); - ClonedVal->insertBefore(HoistPt->getTerminator()); - replaceUseWith(Repl, Val, ClonedVal); - } + if (Val && isa<GetElementPtrInst>(Val)) + makeGepsAvailable(Repl, HoistPt, InstructionsToHoist, Val); return true; } @@ -635,17 +757,21 @@ public: const SmallVecInsn &InstructionsToHoist = HP.second; Instruction *Repl = nullptr; for (Instruction *I : InstructionsToHoist) - if (I->getParent() == HoistPt) { + if (I->getParent() == HoistPt) // If there are two instructions in HoistPt to be hoisted in place: // update Repl to be the first one, such that we can rename the uses // of the second based on the first. - Repl = !Repl ? 
I : firstOfTwo(Repl, I); - } + if (!Repl || firstInBB(I, Repl)) + Repl = I; + // Keep track of whether we moved the instruction so we know whether we + // should move the MemoryAccess. + bool MoveAccess = true; if (Repl) { // Repl is already in HoistPt: it remains in place. assert(allOperandsAvailable(Repl, HoistPt) && "instruction depends on operands that are not available"); + MoveAccess = false; } else { // When we do not find Repl in HoistPt, select the first in the list // and move it to HoistPt. @@ -654,10 +780,39 @@ public: // We can move Repl in HoistPt only when all operands are available. // The order in which hoistings are done may influence the availability // of operands. - if (!allOperandsAvailable(Repl, HoistPt) && - !makeOperandsAvailable(Repl, HoistPt)) - continue; - Repl->moveBefore(HoistPt->getTerminator()); + if (!allOperandsAvailable(Repl, HoistPt)) { + + // When HoistingGeps there is nothing more we can do to make the + // operands available: just continue. + if (HoistingGeps) + continue; + + // When not HoistingGeps we need to copy the GEPs. + if (!makeGepOperandsAvailable(Repl, HoistPt, InstructionsToHoist)) + continue; + } + + // Move the instruction at the end of HoistPt. + Instruction *Last = HoistPt->getTerminator(); + MD->removeInstruction(Repl); + Repl->moveBefore(Last); + + DFSNumber[Repl] = DFSNumber[Last]++; + } + + MemoryAccess *NewMemAcc = MSSA->getMemoryAccess(Repl); + + if (MoveAccess) { + if (MemoryUseOrDef *OldMemAcc = + dyn_cast_or_null<MemoryUseOrDef>(NewMemAcc)) { + // The definition of this ld/st will not change: ld/st hoisting is + // legal when the ld/st is not moved past its current definition. + MemoryAccess *Def = OldMemAcc->getDefiningAccess(); + NewMemAcc = + MSSA->createMemoryAccessInBB(Repl, Def, HoistPt, MemorySSA::End); + OldMemAcc->replaceAllUsesWith(NewMemAcc); + MSSA->removeMemoryAccess(OldMemAcc); + } } if (isa<LoadInst>(Repl)) @@ -673,15 +828,54 @@ public: for (Instruction *I : InstructionsToHoist) if (I != Repl) { ++NR; - if (isa<LoadInst>(Repl)) + if (auto *ReplacementLoad = dyn_cast<LoadInst>(Repl)) { + ReplacementLoad->setAlignment( + std::min(ReplacementLoad->getAlignment(), + cast<LoadInst>(I)->getAlignment())); ++NumLoadsRemoved; - else if (isa<StoreInst>(Repl)) + } else if (auto *ReplacementStore = dyn_cast<StoreInst>(Repl)) { + ReplacementStore->setAlignment( + std::min(ReplacementStore->getAlignment(), + cast<StoreInst>(I)->getAlignment())); ++NumStoresRemoved; - else if (isa<CallInst>(Repl)) + } else if (auto *ReplacementAlloca = dyn_cast<AllocaInst>(Repl)) { + ReplacementAlloca->setAlignment( + std::max(ReplacementAlloca->getAlignment(), + cast<AllocaInst>(I)->getAlignment())); + } else if (isa<CallInst>(Repl)) { ++NumCallsRemoved; + } + + if (NewMemAcc) { + // Update the uses of the old MSSA access with NewMemAcc. + MemoryAccess *OldMA = MSSA->getMemoryAccess(I); + OldMA->replaceAllUsesWith(NewMemAcc); + MSSA->removeMemoryAccess(OldMA); + } + + Repl->andIRFlags(I); + combineKnownMetadata(Repl, I); I->replaceAllUsesWith(Repl); + // Also invalidate the Alias Analysis cache. + MD->removeInstruction(I); I->eraseFromParent(); } + + // Remove MemorySSA phi nodes with the same arguments. 
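When several loads or stores are folded into one hoisted replacement, the code above keeps the replacement conservative: the minimum alignment of the merged loads/stores (maximum for allocas), IR flags intersected with andIRFlags, and only metadata that is valid on all paths via combineMetadata. A tiny sketch of the alignment rule, assuming only the alignment differs between the merged accesses (helper names invented for the sketch):

#include <algorithm>
#include <cassert>

// Hoisted load/store: must not claim more alignment than any access it replaces.
unsigned mergedLoadStoreAlign(unsigned replAlign, unsigned otherAlign) {
  return std::min(replAlign, otherAlign);
}

// Hoisted alloca: the larger alignment satisfies every merged request.
unsigned mergedAllocaAlign(unsigned replAlign, unsigned otherAlign) {
  return std::max(replAlign, otherAlign);
}

int main() {
  assert(mergedLoadStoreAlign(16, 4) == 4);  // weakest guarantee wins
  assert(mergedAllocaAlign(16, 4) == 16);    // strongest requirement wins
  return 0;
}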
+ if (NewMemAcc) { + SmallPtrSet<MemoryPhi *, 4> UsePhis; + for (User *U : NewMemAcc->users()) + if (MemoryPhi *Phi = dyn_cast<MemoryPhi>(U)) + UsePhis.insert(Phi); + + for (auto *Phi : UsePhis) { + auto In = Phi->incoming_values(); + if (all_of(In, [&](Use &U) { return U == NewMemAcc; })) { + Phi->replaceAllUsesWith(NewMemAcc); + MSSA->removeMemoryAccess(Phi); + } + } + } } NumHoisted += NL + NS + NC + NI; @@ -700,7 +894,17 @@ public: StoreInfo SI; CallInfo CI; for (BasicBlock *BB : depth_first(&F.getEntryBlock())) { + int InstructionNb = 0; for (Instruction &I1 : *BB) { + // Only hoist the first instructions in BB up to MaxDepthInBB. Hoisting + // deeper may increase the register pressure and compilation time. + if (MaxDepthInBB != -1 && InstructionNb++ >= MaxDepthInBB) + break; + + // Do not value number terminator instructions. + if (isa<TerminatorInst>(&I1)) + break; + if (auto *Load = dyn_cast<LoadInst>(&I1)) LI.insert(Load, VN); else if (auto *Store = dyn_cast<StoreInst>(&I1)) @@ -711,15 +915,14 @@ public: Intr->getIntrinsicID() == Intrinsic::assume) continue; } - if (Call->mayHaveSideEffects()) { - if (!OptForMinSize) - break; - // We may continue hoisting across calls which write to memory. - if (Call->mayThrow()) - break; - } + if (Call->mayHaveSideEffects()) + break; + + if (Call->isConvergent()) + break; + CI.insert(Call, VN); - } else if (OptForMinSize || !isa<GetElementPtrInst>(&I1)) + } else if (HoistingGeps || !isa<GetElementPtrInst>(&I1)) // Do not hoist scalars past calls that may write to memory because // that could result in spills later. geps are handled separately. // TODO: We can relax this for targets like AArch64 as they have more @@ -737,39 +940,6 @@ public: computeInsertionPoints(CI.getStoreVNTable(), HPL, InsKind::Store); return hoist(HPL); } - - bool run(Function &F) { - VN.setDomTree(DT); - VN.setAliasAnalysis(AA); - VN.setMemDep(MD); - bool Res = false; - - unsigned I = 0; - for (const BasicBlock *BB : depth_first(&F.getEntryBlock())) - DFSNumber.insert({BB, ++I}); - - // FIXME: use lazy evaluation of VN to avoid the fix-point computation. - while (1) { - // FIXME: only compute MemorySSA once. We need to update the analysis in - // the same time as transforming the code. - MemorySSA M(F, AA, DT); - MSSA = &M; - - auto HoistStat = hoistExpressions(F); - if (HoistStat.first + HoistStat.second == 0) { - return Res; - } - if (HoistStat.second > 0) { - // To address a limitation of the current GVN, we need to rerun the - // hoisting after we hoisted loads in order to be able to hoist all - // scalars dependent on the hoisted loads. Same for stores. 
- VN.clear(); - } - Res = true; - } - - return Res; - } }; class GVNHoistLegacyPass : public FunctionPass { @@ -781,11 +951,14 @@ public: } bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); auto &MD = getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); - GVNHoist G(&DT, &AA, &MD, F.optForMinSize()); + GVNHoist G(&DT, &AA, &MD, &MSSA); return G.run(F); } @@ -793,23 +966,25 @@ public: AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addRequired<MemorySSAWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); } }; } // namespace -PreservedAnalyses GVNHoistPass::run(Function &F, - AnalysisManager<Function> &AM) { +PreservedAnalyses GVNHoistPass::run(Function &F, FunctionAnalysisManager &AM) { DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); AliasAnalysis &AA = AM.getResult<AAManager>(F); MemoryDependenceResults &MD = AM.getResult<MemoryDependenceAnalysis>(F); - - GVNHoist G(&DT, &AA, &MD, F.optForMinSize()); + MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); + GVNHoist G(&DT, &AA, &MD, &MSSA); if (!G.run(F)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<MemorySSAAnalysis>(); return PA; } @@ -817,6 +992,7 @@ char GVNHoistLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(GVNHoistLegacyPass, "gvn-hoist", "Early GVN Hoisting of Expressions", false, false) INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_END(GVNHoistLegacyPass, "gvn-hoist", diff --git a/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp index 7686e65..b05ef00 100644 --- a/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -46,6 +46,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" @@ -653,7 +654,7 @@ bool GuardWideningImpl::combineRangeChecks( } PreservedAnalyses GuardWideningPass::run(Function &F, - AnalysisManager<Function> &AM) { + FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); diff --git a/contrib/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp b/contrib/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp new file mode 100644 index 0000000..8075933 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp @@ -0,0 +1,22 @@ +//===- IVUsersPrinter.cpp - Induction Variable Users Printer ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/IVUsersPrinter.h" +#include "llvm/Analysis/IVUsers.h" +#include "llvm/Support/Debug.h" +using namespace llvm; + +#define DEBUG_TYPE "iv-users" + +PreservedAnalyses IVUsersPrinterPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + AM.getResult<IVUsersAnalysis>(L, AR).print(OS); + return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index cf3e7c5..1752fb7 100644 --- a/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -25,15 +25,13 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/IndVarSimplify.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopPassManager.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" @@ -49,6 +47,8 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -79,8 +79,12 @@ static cl::opt<ReplaceExitVal> ReplaceExitValue( clEnumValN(OnlyCheapRepl, "cheap", "only replace exit value when the cost is cheap"), clEnumValN(AlwaysRepl, "always", - "always replace exit value whenever possible"), - clEnumValEnd)); + "always replace exit value whenever possible"))); + +static cl::opt<bool> UsePostIncrementRanges( + "indvars-post-increment-ranges", cl::Hidden, + cl::desc("Use post increment control-dependent ranges in IndVarSimplify"), + cl::init(true)); namespace { struct RewritePhi; @@ -506,7 +510,8 @@ Value *IndVarSimplify::expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, /// constant operands at the beginning of the loop. void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { // Check a pre-condition. - assert(L->isRecursivelyLCSSAForm(*DT) && "Indvars did not preserve LCSSA!"); + assert(L->isRecursivelyLCSSAForm(*DT, *LI) && + "Indvars did not preserve LCSSA!"); SmallVector<BasicBlock*, 8> ExitBlocks; L->getUniqueExitBlocks(ExitBlocks); @@ -880,7 +885,6 @@ class WidenIV { // Parameters PHINode *OrigPhi; Type *WideType; - bool IsSigned; // Context LoopInfo *LI; @@ -888,31 +892,70 @@ class WidenIV { ScalarEvolution *SE; DominatorTree *DT; + // Does the module have any calls to the llvm.experimental.guard intrinsic + // at all? If not we can avoid scanning instructions looking for guards. 
+ bool HasGuards; + // Result PHINode *WidePhi; Instruction *WideInc; const SCEV *WideIncExpr; SmallVectorImpl<WeakVH> &DeadInsts; - SmallPtrSet<Instruction*,16> Widened; + SmallPtrSet<Instruction *,16> Widened; SmallVector<NarrowIVDefUse, 8> NarrowIVUsers; + enum ExtendKind { ZeroExtended, SignExtended, Unknown }; + // A map tracking the kind of extension used to widen each narrow IV + // and narrow IV user. + // Key: pointer to a narrow IV or IV user. + // Value: the kind of extension used to widen this Instruction. + DenseMap<AssertingVH<Instruction>, ExtendKind> ExtendKindMap; + + typedef std::pair<AssertingVH<Value>, AssertingVH<Instruction>> DefUserPair; + // A map with control-dependent ranges for post increment IV uses. The key is + // a pair of IV def and a use of this def denoting the context. The value is + // a ConstantRange representing possible values of the def at the given + // context. + DenseMap<DefUserPair, ConstantRange> PostIncRangeInfos; + + Optional<ConstantRange> getPostIncRangeInfo(Value *Def, + Instruction *UseI) { + DefUserPair Key(Def, UseI); + auto It = PostIncRangeInfos.find(Key); + return It == PostIncRangeInfos.end() + ? Optional<ConstantRange>(None) + : Optional<ConstantRange>(It->second); + } + + void calculatePostIncRanges(PHINode *OrigPhi); + void calculatePostIncRange(Instruction *NarrowDef, Instruction *NarrowUser); + void updatePostIncRangeInfo(Value *Def, Instruction *UseI, ConstantRange R) { + DefUserPair Key(Def, UseI); + auto It = PostIncRangeInfos.find(Key); + if (It == PostIncRangeInfos.end()) + PostIncRangeInfos.insert({Key, R}); + else + It->second = R.intersectWith(It->second); + } + public: WidenIV(const WideIVInfo &WI, LoopInfo *LInfo, ScalarEvolution *SEv, DominatorTree *DTree, - SmallVectorImpl<WeakVH> &DI) : + SmallVectorImpl<WeakVH> &DI, bool HasGuards) : OrigPhi(WI.NarrowIV), WideType(WI.WidestNativeType), - IsSigned(WI.IsSigned), LI(LInfo), L(LI->getLoopFor(OrigPhi->getParent())), SE(SEv), DT(DTree), + HasGuards(HasGuards), WidePhi(nullptr), WideInc(nullptr), WideIncExpr(nullptr), DeadInsts(DI) { assert(L->getHeader() == OrigPhi->getParent() && "Phi must be an IV"); + ExtendKindMap[OrigPhi] = WI.IsSigned ? SignExtended : ZeroExtended; } PHINode *createWideIV(SCEVExpander &Rewriter); @@ -926,9 +969,13 @@ protected: const SCEVAddRecExpr *WideAR); Instruction *cloneBitwiseIVUser(NarrowIVDefUse DU); - const SCEVAddRecExpr *getWideRecurrence(Instruction *NarrowUse); + ExtendKind getExtendKind(Instruction *I); - const SCEVAddRecExpr* getExtendedOperandRecurrence(NarrowIVDefUse DU); + typedef std::pair<const SCEVAddRecExpr *, ExtendKind> WidenedRecTy; + + WidenedRecTy getWideRecurrence(NarrowIVDefUse DU); + + WidenedRecTy getExtendedOperandRecurrence(NarrowIVDefUse DU); const SCEV *getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, unsigned OpCode) const; @@ -1002,6 +1049,7 @@ Instruction *WidenIV::cloneBitwiseIVUser(NarrowIVDefUse DU) { // about the narrow operand yet so must insert a [sz]ext. It is probably loop // invariant and will be folded or hoisted. If it actually comes from a // widened IV, it should be removed during a future call to widenIVUse. + bool IsSigned = getExtendKind(NarrowDef) == SignExtended; Value *LHS = (NarrowUse->getOperand(0) == NarrowDef) ? 
WideDef : createExtendInst(NarrowUse->getOperand(0), WideType, @@ -1086,7 +1134,7 @@ Instruction *WidenIV::cloneArithmeticIVUser(NarrowIVDefUse DU, return WideUse == WideAR; }; - bool SignExtend = IsSigned; + bool SignExtend = getExtendKind(NarrowDef) == SignExtended; if (!GuessNonIVOperand(SignExtend)) { SignExtend = !SignExtend; if (!GuessNonIVOperand(SignExtend)) @@ -1112,6 +1160,12 @@ Instruction *WidenIV::cloneArithmeticIVUser(NarrowIVDefUse DU, return WideBO; } +WidenIV::ExtendKind WidenIV::getExtendKind(Instruction *I) { + auto It = ExtendKindMap.find(I); + assert(It != ExtendKindMap.end() && "Instruction not yet extended!"); + return It->second; +} + const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, unsigned OpCode) const { if (OpCode == Instruction::Add) @@ -1127,15 +1181,16 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, /// No-wrap operations can transfer sign extension of their result to their /// operands. Generate the SCEV value for the widened operation without /// actually modifying the IR yet. If the expression after extending the -/// operands is an AddRec for this loop, return it. -const SCEVAddRecExpr* WidenIV::getExtendedOperandRecurrence(NarrowIVDefUse DU) { +/// operands is an AddRec for this loop, return the AddRec and the kind of +/// extension used. +WidenIV::WidenedRecTy WidenIV::getExtendedOperandRecurrence(NarrowIVDefUse DU) { // Handle the common case of add<nsw/nuw> const unsigned OpCode = DU.NarrowUse->getOpcode(); // Only Add/Sub/Mul instructions supported yet. if (OpCode != Instruction::Add && OpCode != Instruction::Sub && OpCode != Instruction::Mul) - return nullptr; + return {nullptr, Unknown}; // One operand (NarrowDef) has already been extended to WideDef. Now determine // if extending the other will lead to a recurrence. @@ -1146,14 +1201,15 @@ const SCEVAddRecExpr* WidenIV::getExtendedOperandRecurrence(NarrowIVDefUse DU) { const SCEV *ExtendOperExpr = nullptr; const OverflowingBinaryOperator *OBO = cast<OverflowingBinaryOperator>(DU.NarrowUse); - if (IsSigned && OBO->hasNoSignedWrap()) + ExtendKind ExtKind = getExtendKind(DU.NarrowDef); + if (ExtKind == SignExtended && OBO->hasNoSignedWrap()) ExtendOperExpr = SE->getSignExtendExpr( SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType); - else if(!IsSigned && OBO->hasNoUnsignedWrap()) + else if(ExtKind == ZeroExtended && OBO->hasNoUnsignedWrap()) ExtendOperExpr = SE->getZeroExtendExpr( SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType); else - return nullptr; + return {nullptr, Unknown}; // When creating this SCEV expr, don't apply the current operations NSW or NUW // flags. This instruction may be guarded by control flow that the no-wrap @@ -1171,33 +1227,49 @@ const SCEVAddRecExpr* WidenIV::getExtendedOperandRecurrence(NarrowIVDefUse DU) { dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, OpCode)); if (!AddRec || AddRec->getLoop() != L) - return nullptr; - return AddRec; + return {nullptr, Unknown}; + + return {AddRec, ExtKind}; } /// Is this instruction potentially interesting for further simplification after /// widening it's type? In other words, can the extend be safely hoisted out of /// the loop with SCEV reducing the value to a recurrence on the same loop. If -/// so, return the sign or zero extended recurrence. Otherwise return NULL. 
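getExtendedOperandRecurrence above relies on the arithmetic fact that a no-signed-wrap add lets the sign extension be pushed through the operation (and likewise zext through a nuw add). The brute-force check below verifies that identity for 8-bit values widened to 32 bits; it is a sanity check of the arithmetic fact only, not of the SCEV code itself.

#include <cassert>
#include <cstdint>

int main() {
  for (int a = -128; a <= 127; ++a)
    for (int b = -128; b <= 127; ++b) {
      int sum = a + b;
      if (sum < -128 || sum > 127)
        continue;                        // would wrap: "nsw" does not hold
      // sext(a +nsw b) == sext(a) + sext(b) when widening i8 -> i32.
      int32_t narrowThenExtend = (int32_t)(int8_t)sum;
      int32_t extendThenAdd = (int32_t)(int8_t)a + (int32_t)(int8_t)b;
      assert(narrowThenExtend == extendThenAdd);
    }
  return 0;
}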
-const SCEVAddRecExpr *WidenIV::getWideRecurrence(Instruction *NarrowUse) { - if (!SE->isSCEVable(NarrowUse->getType())) - return nullptr; - - const SCEV *NarrowExpr = SE->getSCEV(NarrowUse); - if (SE->getTypeSizeInBits(NarrowExpr->getType()) - >= SE->getTypeSizeInBits(WideType)) { +/// so, return the extended recurrence and the kind of extension used. Otherwise +/// return {nullptr, Unknown}. +WidenIV::WidenedRecTy WidenIV::getWideRecurrence(NarrowIVDefUse DU) { + if (!SE->isSCEVable(DU.NarrowUse->getType())) + return {nullptr, Unknown}; + + const SCEV *NarrowExpr = SE->getSCEV(DU.NarrowUse); + if (SE->getTypeSizeInBits(NarrowExpr->getType()) >= + SE->getTypeSizeInBits(WideType)) { // NarrowUse implicitly widens its operand. e.g. a gep with a narrow // index. So don't follow this use. - return nullptr; + return {nullptr, Unknown}; } - const SCEV *WideExpr = IsSigned ? - SE->getSignExtendExpr(NarrowExpr, WideType) : - SE->getZeroExtendExpr(NarrowExpr, WideType); + const SCEV *WideExpr; + ExtendKind ExtKind; + if (DU.NeverNegative) { + WideExpr = SE->getSignExtendExpr(NarrowExpr, WideType); + if (isa<SCEVAddRecExpr>(WideExpr)) + ExtKind = SignExtended; + else { + WideExpr = SE->getZeroExtendExpr(NarrowExpr, WideType); + ExtKind = ZeroExtended; + } + } else if (getExtendKind(DU.NarrowDef) == SignExtended) { + WideExpr = SE->getSignExtendExpr(NarrowExpr, WideType); + ExtKind = SignExtended; + } else { + WideExpr = SE->getZeroExtendExpr(NarrowExpr, WideType); + ExtKind = ZeroExtended; + } const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(WideExpr); if (!AddRec || AddRec->getLoop() != L) - return nullptr; - return AddRec; + return {nullptr, Unknown}; + return {AddRec, ExtKind}; } /// This IV user cannot be widen. Replace this use of the original narrow IV @@ -1233,7 +1305,7 @@ bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) { // // (A) == icmp slt i32 sext(%narrow), sext(%val) // == icmp slt i32 zext(%narrow), sext(%val) - + bool IsSigned = getExtendKind(DU.NarrowDef) == SignExtended; if (!(DU.NeverNegative || IsSigned == Cmp->isSigned())) return false; @@ -1258,6 +1330,8 @@ bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) { /// Determine whether an individual user of the narrow IV can be widened. If so, /// return the wide clone of the user. Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { + assert(ExtendKindMap.count(DU.NarrowDef) && + "Should already know the kind of extension used to widen NarrowDef"); // Stop traversing the def-use chain at inner-loop phis or post-loop phis. if (PHINode *UsePhi = dyn_cast<PHINode>(DU.NarrowUse)) { @@ -1288,8 +1362,19 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { return nullptr; } } + + // This narrow use can be widened by a sext if it's non-negative or its narrow + // def was widended by a sext. Same for zext. + auto canWidenBySExt = [&]() { + return DU.NeverNegative || getExtendKind(DU.NarrowDef) == SignExtended; + }; + auto canWidenByZExt = [&]() { + return DU.NeverNegative || getExtendKind(DU.NarrowDef) == ZeroExtended; + }; + // Our raison d'etre! Eliminate sign and zero extension. - if (IsSigned ? 
isa<SExtInst>(DU.NarrowUse) : isa<ZExtInst>(DU.NarrowUse)) { + if ((isa<SExtInst>(DU.NarrowUse) && canWidenBySExt()) || + (isa<ZExtInst>(DU.NarrowUse) && canWidenByZExt())) { Value *NewDef = DU.WideDef; if (DU.NarrowUse->getType() != WideType) { unsigned CastWidth = SE->getTypeSizeInBits(DU.NarrowUse->getType()); @@ -1327,17 +1412,18 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { } // Does this user itself evaluate to a recurrence after widening? - const SCEVAddRecExpr *WideAddRec = getWideRecurrence(DU.NarrowUse); - if (!WideAddRec) - WideAddRec = getExtendedOperandRecurrence(DU); + WidenedRecTy WideAddRec = getExtendedOperandRecurrence(DU); + if (!WideAddRec.first) + WideAddRec = getWideRecurrence(DU); - if (!WideAddRec) { + assert((WideAddRec.first == nullptr) == (WideAddRec.second == Unknown)); + if (!WideAddRec.first) { // If use is a loop condition, try to promote the condition instead of // truncating the IV first. if (widenLoopCompare(DU)) return nullptr; - // This user does not evaluate to a recurence after widening, so don't + // This user does not evaluate to a recurrence after widening, so don't // follow it. Instead insert a Trunc to kill off the original use, // eventually isolating the original narrow IV so it can be removed. truncateIVUse(DU, DT, LI); @@ -1351,10 +1437,11 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { // Reuse the IV increment that SCEVExpander created as long as it dominates // NarrowUse. Instruction *WideUse = nullptr; - if (WideAddRec == WideIncExpr && Rewriter.hoistIVInc(WideInc, DU.NarrowUse)) + if (WideAddRec.first == WideIncExpr && + Rewriter.hoistIVInc(WideInc, DU.NarrowUse)) WideUse = WideInc; else { - WideUse = cloneIVUser(DU, WideAddRec); + WideUse = cloneIVUser(DU, WideAddRec.first); if (!WideUse) return nullptr; } @@ -1363,13 +1450,14 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { // evaluates to the same expression as the extended narrow use, but doesn't // absolutely guarantee it. Hence the following failsafe check. In rare cases // where it fails, we simply throw away the newly created wide use. - if (WideAddRec != SE->getSCEV(WideUse)) { + if (WideAddRec.first != SE->getSCEV(WideUse)) { DEBUG(dbgs() << "Wide use expression mismatch: " << *WideUse - << ": " << *SE->getSCEV(WideUse) << " != " << *WideAddRec << "\n"); + << ": " << *SE->getSCEV(WideUse) << " != " << *WideAddRec.first << "\n"); DeadInsts.emplace_back(WideUse); return nullptr; } + ExtendKindMap[DU.NarrowUse] = WideAddRec.second; // Returning WideUse pushes it on the worklist. return WideUse; } @@ -1378,7 +1466,7 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { /// void WidenIV::pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef) { const SCEV *NarrowSCEV = SE->getSCEV(NarrowDef); - bool NeverNegative = + bool NonNegativeDef = SE->isKnownPredicate(ICmpInst::ICMP_SGE, NarrowSCEV, SE->getConstant(NarrowSCEV->getType(), 0)); for (User *U : NarrowDef->users()) { @@ -1388,7 +1476,15 @@ void WidenIV::pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef) { if (!Widened.insert(NarrowUser).second) continue; - NarrowIVUsers.emplace_back(NarrowDef, NarrowUser, WideDef, NeverNegative); + bool NonNegativeUse = false; + if (!NonNegativeDef) { + // We might have a control-dependent range information for this context. 
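The canWidenBySExt/canWidenByZExt lambdas above, and the NonNegativeDef/NonNegativeUse logic that follows, exploit one fact: for a value known to be non-negative, sign- and zero-extension produce the same wide value, so either kind of extend of the narrow IV can be rewritten in terms of the widened IV. A brute-force confirmation for i8 -> i32 (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  for (int v = 0; v <= 127; ++v) {       // non-negative i8 values only
    int32_t sext = (int32_t)(int8_t)v;   // sign-extend i8 -> i32
    int32_t zext = (int32_t)(uint8_t)v;  // zero-extend i8 -> i32
    assert(sext == zext);                // identical when the sign bit is clear
  }
  return 0;
}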
+ if (auto RangeInfo = getPostIncRangeInfo(NarrowDef, NarrowUser)) + NonNegativeUse = RangeInfo->getSignedMin().isNonNegative(); + } + + NarrowIVUsers.emplace_back(NarrowDef, NarrowUser, WideDef, + NonNegativeDef || NonNegativeUse); } } @@ -1408,9 +1504,9 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { return nullptr; // Widen the induction variable expression. - const SCEV *WideIVExpr = IsSigned ? - SE->getSignExtendExpr(AddRec, WideType) : - SE->getZeroExtendExpr(AddRec, WideType); + const SCEV *WideIVExpr = getExtendKind(OrigPhi) == SignExtended + ? SE->getSignExtendExpr(AddRec, WideType) + : SE->getZeroExtendExpr(AddRec, WideType); assert(SE->getEffectiveSCEVType(WideIVExpr->getType()) == WideType && "Expect the new IV expression to preserve its type"); @@ -1428,6 +1524,19 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { SE->properlyDominates(AddRec->getStepRecurrence(*SE), L->getHeader()) && "Loop header phi recurrence inputs do not dominate the loop"); + // Iterate over IV uses (including transitive ones) looking for IV increments + // of the form 'add nsw %iv, <const>'. For each increment and each use of + // the increment calculate control-dependent range information basing on + // dominating conditions inside of the loop (e.g. a range check inside of the + // loop). Calculated ranges are stored in PostIncRangeInfos map. + // + // Control-dependent range information is later used to prove that a narrow + // definition is not negative (see pushNarrowIVUsers). It's difficult to do + // this on demand because when pushNarrowIVUsers needs this information some + // of the dominating conditions might be already widened. + if (UsePostIncrementRanges) + calculatePostIncRanges(OrigPhi); + // The rewriter provides a value for the desired IV expression. This may // either find an existing phi or materialize a new one. Either way, we // expect a well-formed cyclic phi-with-increments. i.e. any operand not part @@ -1443,6 +1552,11 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { WideInc = cast<Instruction>(WidePhi->getIncomingValueForBlock(LatchBlock)); WideIncExpr = SE->getSCEV(WideInc); + // Propagate the debug location associated with the original loop increment + // to the new (widened) increment. + auto *OrigInc = + cast<Instruction>(OrigPhi->getIncomingValueForBlock(LatchBlock)); + WideInc->setDebugLoc(OrigInc->getDebugLoc()); } DEBUG(dbgs() << "Wide IV: " << *WidePhi << "\n"); @@ -1472,6 +1586,114 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { return WidePhi; } +/// Calculates control-dependent range for the given def at the given context +/// by looking at dominating conditions inside of the loop +void WidenIV::calculatePostIncRange(Instruction *NarrowDef, + Instruction *NarrowUser) { + using namespace llvm::PatternMatch; + + Value *NarrowDefLHS; + const APInt *NarrowDefRHS; + if (!match(NarrowDef, m_NSWAdd(m_Value(NarrowDefLHS), + m_APInt(NarrowDefRHS))) || + !NarrowDefRHS->isNonNegative()) + return; + + auto UpdateRangeFromCondition = [&] (Value *Condition, + bool TrueDest) { + CmpInst::Predicate Pred; + Value *CmpRHS; + if (!match(Condition, m_ICmp(Pred, m_Specific(NarrowDefLHS), + m_Value(CmpRHS)))) + return; + + CmpInst::Predicate P = + TrueDest ? 
Pred : CmpInst::getInversePredicate(Pred); + + auto CmpRHSRange = SE->getSignedRange(SE->getSCEV(CmpRHS)); + auto CmpConstrainedLHSRange = + ConstantRange::makeAllowedICmpRegion(P, CmpRHSRange); + auto NarrowDefRange = + CmpConstrainedLHSRange.addWithNoSignedWrap(*NarrowDefRHS); + + updatePostIncRangeInfo(NarrowDef, NarrowUser, NarrowDefRange); + }; + + auto UpdateRangeFromGuards = [&](Instruction *Ctx) { + if (!HasGuards) + return; + + for (Instruction &I : make_range(Ctx->getIterator().getReverse(), + Ctx->getParent()->rend())) { + Value *C = nullptr; + if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(C)))) + UpdateRangeFromCondition(C, /*TrueDest=*/true); + } + }; + + UpdateRangeFromGuards(NarrowUser); + + BasicBlock *NarrowUserBB = NarrowUser->getParent(); + // If NarrowUserBB is statically unreachable asking dominator queries may + // yield surprising results. (e.g. the block may not have a dom tree node) + if (!DT->isReachableFromEntry(NarrowUserBB)) + return; + + for (auto *DTB = (*DT)[NarrowUserBB]->getIDom(); + L->contains(DTB->getBlock()); + DTB = DTB->getIDom()) { + auto *BB = DTB->getBlock(); + auto *TI = BB->getTerminator(); + UpdateRangeFromGuards(TI); + + auto *BI = dyn_cast<BranchInst>(TI); + if (!BI || !BI->isConditional()) + continue; + + auto *TrueSuccessor = BI->getSuccessor(0); + auto *FalseSuccessor = BI->getSuccessor(1); + + auto DominatesNarrowUser = [this, NarrowUser] (BasicBlockEdge BBE) { + return BBE.isSingleEdge() && + DT->dominates(BBE, NarrowUser->getParent()); + }; + + if (DominatesNarrowUser(BasicBlockEdge(BB, TrueSuccessor))) + UpdateRangeFromCondition(BI->getCondition(), /*TrueDest=*/true); + + if (DominatesNarrowUser(BasicBlockEdge(BB, FalseSuccessor))) + UpdateRangeFromCondition(BI->getCondition(), /*TrueDest=*/false); + } +} + +/// Calculates PostIncRangeInfos map for the given IV +void WidenIV::calculatePostIncRanges(PHINode *OrigPhi) { + SmallPtrSet<Instruction *, 16> Visited; + SmallVector<Instruction *, 6> Worklist; + Worklist.push_back(OrigPhi); + Visited.insert(OrigPhi); + + while (!Worklist.empty()) { + Instruction *NarrowDef = Worklist.pop_back_val(); + + for (Use &U : NarrowDef->uses()) { + auto *NarrowUser = cast<Instruction>(U.getUser()); + + // Don't go looking outside the current loop. + auto *NarrowUserLoop = (*LI)[NarrowUser->getParent()]; + if (!NarrowUserLoop || !L->contains(NarrowUserLoop)) + continue; + + if (!Visited.insert(NarrowUser).second) + continue; + + Worklist.push_back(NarrowUser); + + calculatePostIncRange(NarrowDef, NarrowUser); + } + } +} + //===----------------------------------------------------------------------===// // Live IV Reduction - Minimize IVs live across the loop. 
//===----------------------------------------------------------------------===// @@ -1514,6 +1736,10 @@ void IndVarSimplify::simplifyAndExtend(Loop *L, LoopInfo *LI) { SmallVector<WideIVInfo, 8> WideIVs; + auto *GuardDecl = L->getBlocks()[0]->getModule()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + bool HasGuards = GuardDecl && !GuardDecl->use_empty(); + SmallVector<PHINode*, 8> LoopPhis; for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { LoopPhis.push_back(cast<PHINode>(I)); @@ -1543,7 +1769,7 @@ void IndVarSimplify::simplifyAndExtend(Loop *L, } while(!LoopPhis.empty()); for (; !WideIVs.empty(); WideIVs.pop_back()) { - WidenIV Widener(WideIVs.back(), LI, SE, DT, DeadInsts); + WidenIV Widener(WideIVs.back(), LI, SE, DT, DeadInsts, HasGuards); if (PHINode *WidePhi = Widener.createWideIV(Rewriter)) { Changed = true; LoopPhis.push_back(WidePhi); @@ -1870,7 +2096,7 @@ static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, return Builder.CreateGEP(nullptr, GEPBase, GEPOffset, "lftr.limit"); } else { // In any other case, convert both IVInit and IVCount to integers before - // comparing. This may result in SCEV expension of pointers, but in practice + // comparing. This may result in SCEV expansion of pointers, but in practice // SCEV will fold the pointer arithmetic away as such: // BECount = (IVEnd - IVInit - 1) => IVLimit = IVInit (postinc). // @@ -1963,6 +2189,11 @@ linearFunctionTestReplace(Loop *L, IRBuilder<> Builder(BI); + // The new loop exit condition should reuse the debug location of the + // original loop exit condition. + if (auto *Cond = dyn_cast<Instruction>(BI->getCondition())) + Builder.SetCurrentDebugLocation(Cond->getDebugLoc()); + // LFTR can ignore IV overflow and truncate to the width of // BECount. This avoids materializing the add(zext(add)) expression. unsigned CmpIndVarSize = SE->getTypeSizeInBits(CmpIndVar->getType()); @@ -1992,8 +2223,36 @@ linearFunctionTestReplace(Loop *L, DEBUG(dbgs() << " Widen RHS:\t" << *ExitCnt << "\n"); } else { - CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(), - "lftr.wideiv"); + // We try to extend trip count first. If that doesn't work we truncate IV. + // Zext(trunc(IV)) == IV implies equivalence of the following two: + // Trunc(IV) == ExitCnt and IV == zext(ExitCnt). Similarly for sext. If + // one of the two holds, extend the trip count, otherwise we truncate IV. 
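The comment above states the equivalence LFTR relies on: when zext(trunc(IV)) == IV, comparing the truncated IV against ExitCnt is the same as comparing IV against zext(ExitCnt), so the trip count can be widened instead of the IV being narrowed. A brute-force check of that implication for an i16 IV truncated to i8 (a verification of the arithmetic claim, not of the pass):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t iv = 0; iv <= 0xFFFF; ++iv)
    for (uint32_t cnt = 0; cnt <= 0xFF; ++cnt) {
      uint16_t IV = (uint16_t)iv;
      uint8_t ExitCnt = (uint8_t)cnt;
      bool lossless = (uint16_t)(uint8_t)IV == IV;  // zext(trunc(IV)) == IV
      if (!lossless)
        continue;
      // trunc(IV) == ExitCnt  <=>  IV == zext(ExitCnt)
      assert(((uint8_t)IV == ExitCnt) == (IV == (uint16_t)ExitCnt));
    }
  return 0;
}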
+ bool Extended = false; + const SCEV *IV = SE->getSCEV(CmpIndVar); + const SCEV *ZExtTrunc = + SE->getZeroExtendExpr(SE->getTruncateExpr(SE->getSCEV(CmpIndVar), + ExitCnt->getType()), + CmpIndVar->getType()); + + if (ZExtTrunc == IV) { + Extended = true; + ExitCnt = Builder.CreateZExt(ExitCnt, IndVar->getType(), + "wide.trip.count"); + } else { + const SCEV *SExtTrunc = + SE->getSignExtendExpr(SE->getTruncateExpr(SE->getSCEV(CmpIndVar), + ExitCnt->getType()), + CmpIndVar->getType()); + if (SExtTrunc == IV) { + Extended = true; + ExitCnt = Builder.CreateSExt(ExitCnt, IndVar->getType(), + "wide.trip.count"); + } + } + + if (!Extended) + CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(), + "lftr.wideiv"); } } Value *Cond = Builder.CreateICmp(P, CmpIndVar, ExitCnt, "exitcond"); @@ -2025,7 +2284,7 @@ void IndVarSimplify::sinkUnusedInvariants(Loop *L) { BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) return; - Instruction *InsertPt = &*ExitBlock->getFirstInsertionPt(); + BasicBlock::iterator InsertPt = ExitBlock->getFirstInsertionPt(); BasicBlock::iterator I(Preheader->getTerminator()); while (I != Preheader->begin()) { --I; @@ -2094,9 +2353,9 @@ void IndVarSimplify::sinkUnusedInvariants(Loop *L) { Done = true; } - ToMove->moveBefore(InsertPt); + ToMove->moveBefore(*ExitBlock, InsertPt); if (Done) break; - InsertPt = ToMove; + InsertPt = ToMove->getIterator(); } } @@ -2106,7 +2365,8 @@ void IndVarSimplify::sinkUnusedInvariants(Loop *L) { bool IndVarSimplify::run(Loop *L) { // We need (and expect!) the incoming loop to be in LCSSA. - assert(L->isRecursivelyLCSSAForm(*DT) && "LCSSA required to run indvars!"); + assert(L->isRecursivelyLCSSAForm(*DT, *LI) && + "LCSSA required to run indvars!"); // If LoopSimplify form is not available, stay out of trouble. Some notes: // - LSR currently only supports LoopSimplify-form loops. Indvars' @@ -2199,7 +2459,8 @@ bool IndVarSimplify::run(Loop *L) { Changed |= DeleteDeadPHIs(L->getHeader(), TLI); // Check a post-condition. - assert(L->isRecursivelyLCSSAForm(*DT) && "Indvars did not preserve LCSSA!"); + assert(L->isRecursivelyLCSSAForm(*DT, *LI) && + "Indvars did not preserve LCSSA!"); // Verify that LFTR, and any other change have not interfered with SCEV's // ability to compute trip count. @@ -2221,23 +2482,13 @@ bool IndVarSimplify::run(Loop *L) { return Changed; } -PreservedAnalyses IndVarSimplifyPass::run(Loop &L, AnalysisManager<Loop> &AM) { - auto &FAM = AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); +PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { Function *F = L.getHeader()->getParent(); const DataLayout &DL = F->getParent()->getDataLayout(); - auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); - auto *SE = FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); - auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); - - assert((LI && SE && DT) && - "Analyses required for indvarsimplify not available!"); - - // Optional analyses. 
- auto *TTI = FAM.getCachedResult<TargetIRAnalysis>(*F); - auto *TLI = FAM.getCachedResult<TargetLibraryAnalysis>(*F); - - IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI); + IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI); if (!IVS.run(&L)) return PreservedAnalyses::all(); diff --git a/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index ec7f09a..8e81541 100644 --- a/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -43,21 +43,16 @@ #include "llvm/ADT/Optional.h" #include "llvm/Analysis/BranchProbabilityInfo.h" -#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/IR/ValueHandle.h" -#include "llvm/IR/Verifier.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -65,8 +60,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopUtils.h" -#include "llvm/Transforms/Utils/SimplifyIndVar.h" -#include "llvm/Transforms/Utils/UnrollLoop.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" using namespace llvm; @@ -82,6 +76,11 @@ static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden, static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal", cl::Hidden, cl::init(10)); +static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks", + cl::Hidden, cl::init(false)); + +static const char *ClonedLoopTag = "irce.loop.clone"; + #define DEBUG_TYPE "irce" namespace { @@ -152,11 +151,10 @@ public: OS << " Operand: " << getCheckUse()->getOperandNo() << "\n"; } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + LLVM_DUMP_METHOD void dump() { print(dbgs()); } -#endif Use *getCheckUse() const { return CheckUse; } @@ -276,7 +274,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, case ICmpInst::ICMP_SLE: std::swap(LHS, RHS); - // fallthrough + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SGE: if (match(RHS, m_ConstantInt<0>())) { Index = LHS; @@ -286,7 +284,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, case ICmpInst::ICMP_SLT: std::swap(LHS, RHS); - // fallthrough + LLVM_FALLTHROUGH; case ICmpInst::ICMP_SGT: if (match(RHS, m_ConstantInt<-1>())) { Index = LHS; @@ -302,7 +300,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, case ICmpInst::ICMP_ULT: std::swap(LHS, RHS); - // fallthrough + LLVM_FALLTHROUGH; case ICmpInst::ICMP_UGT: if (IsNonNegativeAndNotLoopVarying(LHS)) { Index = RHS; @@ -392,7 +390,8 @@ void InductiveRangeCheck::extractRangeChecksFromBranch( BranchProbability LikelyTaken(15, 16); - if (BPI.getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken) + if (!SkipProfitabilityChecks && + BPI.getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken) return; SmallPtrSet<Value *, 8> Visited; @@ -400,6 +399,34 @@ void InductiveRangeCheck::extractRangeChecksFromBranch( Checks, Visited); } +// Add metadata to 
the loop L to disable loop optimizations. Callers need to +// confirm that optimizing loop L is not beneficial. +static void DisableAllLoopOptsOnLoop(Loop &L) { + // We do not care about any existing loopID related metadata for L, since we + // are setting all loop metadata to false. + LLVMContext &Context = L.getHeader()->getContext(); + // Reserve first location for self reference to the LoopID metadata node. + MDNode *Dummy = MDNode::get(Context, {}); + MDNode *DisableUnroll = MDNode::get( + Context, {MDString::get(Context, "llvm.loop.unroll.disable")}); + Metadata *FalseVal = + ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0)); + MDNode *DisableVectorize = MDNode::get( + Context, + {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal}); + MDNode *DisableLICMVersioning = MDNode::get( + Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")}); + MDNode *DisableDistribution= MDNode::get( + Context, + {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal}); + MDNode *NewLoopID = + MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize, + DisableLICMVersioning, DisableDistribution}); + // Set operand 0 to refer to the loop id itself. + NewLoopID->replaceOperandWith(0, NewLoopID); + L.setLoopID(NewLoopID); +} + namespace { // Keeps track of the structure of a loop. This is similar to llvm::Loop, @@ -515,6 +542,11 @@ class LoopConstrainer { // void cloneLoop(ClonedLoop &CLResult, const char *Tag) const; + // Create the appropriate loop structure needed to describe a cloned copy of + // `Original`. The clone is described by `VM`. + Loop *createClonedLoopStructure(Loop *Original, Loop *Parent, + ValueToValueMapTy &VM); + // Rewrite the iteration space of the loop denoted by (LS, Preheader). The // iteration space of the rewritten loop ends at ExitLoopAt. The start of the // iteration space is not changed. `ExitLoopAt' is assumed to be slt @@ -566,10 +598,12 @@ class LoopConstrainer { Function &F; LLVMContext &Ctx; ScalarEvolution &SE; + DominatorTree &DT; + LPPassManager &LPM; + LoopInfo &LI; // Information about the original loop we started out with. Loop &OriginalLoop; - LoopInfo &OriginalLoopInfo; const SCEV *LatchTakenCount; BasicBlock *OriginalPreheader; @@ -585,12 +619,13 @@ class LoopConstrainer { LoopStructure MainLoopStructure; public: - LoopConstrainer(Loop &L, LoopInfo &LI, const LoopStructure &LS, - ScalarEvolution &SE, InductiveRangeCheck::Range R) + LoopConstrainer(Loop &L, LoopInfo &LI, LPPassManager &LPM, + const LoopStructure &LS, ScalarEvolution &SE, + DominatorTree &DT, InductiveRangeCheck::Range R) : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), - SE(SE), OriginalLoop(L), OriginalLoopInfo(LI), LatchTakenCount(nullptr), - OriginalPreheader(nullptr), MainLoopPreheader(nullptr), Range(R), - MainLoopStructure(LS) {} + SE(SE), DT(DT), LPM(LPM), LI(LI), OriginalLoop(L), + LatchTakenCount(nullptr), OriginalPreheader(nullptr), + MainLoopPreheader(nullptr), Range(R), MainLoopStructure(LS) {} // Entry point for the algorithm. Returns true on success. 
bool run(); @@ -622,9 +657,19 @@ static bool CanBeSMin(ScalarEvolution &SE, const SCEV *S) { Optional<LoopStructure> LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BPI, Loop &L, const char *&FailureReason) { - assert(L.isLoopSimplifyForm() && "should follow from addRequired<>"); + if (!L.isLoopSimplifyForm()) { + FailureReason = "loop not in LoopSimplify form"; + return None; + } BasicBlock *Latch = L.getLoopLatch(); + assert(Latch && "Simplified loops only have one latch!"); + + if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) { + FailureReason = "loop has already been cloned"; + return None; + } + if (!L.isLoopExiting(Latch)) { FailureReason = "no loop latch"; return None; @@ -648,7 +693,8 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP BranchProbability ExitProbability = BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx); - if (ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { + if (!SkipProfitabilityChecks && + ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { FailureReason = "short running loop, not profitable"; return None; } @@ -907,6 +953,11 @@ void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, return static_cast<Value *>(It->second); }; + auto *ClonedLatch = + cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch())); + ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag, + MDNode::get(Ctx, {})); + Result.Structure = MainLoopStructure.map(GetClonedValue); Result.Structure.Tag = Tag; @@ -924,17 +975,15 @@ void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, // to be edited to reflect that. No phi nodes need to be introduced because // the loop is in LCSSA. - for (auto SBBI = succ_begin(OriginalBB), SBBE = succ_end(OriginalBB); - SBBI != SBBE; ++SBBI) { - - if (OriginalLoop.contains(*SBBI)) + for (auto *SBB : successors(OriginalBB)) { + if (OriginalLoop.contains(SBB)) continue; // not an exit block - for (Instruction &I : **SBBI) { - if (!isa<PHINode>(&I)) + for (Instruction &I : *SBB) { + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) break; - PHINode *PN = cast<PHINode>(&I); Value *OldIncoming = PN->getIncomingValueForBlock(OriginalBB); PN->addIncoming(GetClonedValue(OldIncoming), ClonedBB); } @@ -1020,11 +1069,11 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( RewrittenRangeInfo RRI; - auto BBInsertLocation = std::next(Function::iterator(LS.Latch)); + BasicBlock *BBInsertLocation = LS.Latch->getNextNode(); RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector", - &F, &*BBInsertLocation); + &F, BBInsertLocation); RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F, - &*BBInsertLocation); + BBInsertLocation); BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); bool Increasing = LS.IndVarIncreasing; @@ -1067,11 +1116,10 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // each of the PHI nodes in the loop header. This feeds into the initial // value of the same PHI nodes if/when we continue execution. 
for (Instruction &I : *LS.Header) { - if (!isa<PHINode>(&I)) + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) break; - PHINode *PN = cast<PHINode>(&I); - PHINode *NewPHI = PHINode::Create(PN->getType(), 2, PN->getName() + ".copy", BranchToContinuation); @@ -1104,11 +1152,10 @@ void LoopConstrainer::rewriteIncomingValuesForPHIs( unsigned PHIIndex = 0; for (Instruction &I : *LS.Header) { - if (!isa<PHINode>(&I)) + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) break; - PHINode *PN = cast<PHINode>(&I); - for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) if (PN->getIncomingBlock(i) == ContinuationBlock) PN->setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); @@ -1125,10 +1172,10 @@ BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, BranchInst::Create(LS.Header, Preheader); for (Instruction &I : *LS.Header) { - if (!isa<PHINode>(&I)) + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) break; - PHINode *PN = cast<PHINode>(&I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i < e; ++i) replacePHIBlock(PN, OldPreheader, Preheader); } @@ -1142,7 +1189,23 @@ void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { return; for (BasicBlock *BB : BBs) - ParentLoop->addBasicBlockToLoop(BB, OriginalLoopInfo); + ParentLoop->addBasicBlockToLoop(BB, LI); +} + +Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, + ValueToValueMapTy &VM) { + Loop &New = LPM.addLoop(Parent); + + // Add all of the blocks in Original to the new loop. + for (auto *BB : Original->blocks()) + if (LI.getLoopFor(BB) == Original) + New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI); + + // Add all of the subloops to the new loop. + for (Loop *SubLoop : *Original) + createClonedLoopStructure(SubLoop, &New, VM); + + return &New; } bool LoopConstrainer::run() { @@ -1266,8 +1329,31 @@ bool LoopConstrainer::run() { std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr); addToParentLoopIfNeeded(makeArrayRef(std::begin(NewBlocks), NewBlocksEnd)); - addToParentLoopIfNeeded(PreLoop.Blocks); - addToParentLoopIfNeeded(PostLoop.Blocks); + + DT.recalculate(F); + + if (!PreLoop.Blocks.empty()) { + auto *L = createClonedLoopStructure( + &OriginalLoop, OriginalLoop.getParentLoop(), PreLoop.Map); + formLCSSARecursively(*L, DT, &LI, &SE); + simplifyLoop(L, &DT, &LI, &SE, nullptr, true); + // Pre loops are slow paths, we do not need to perform any loop + // optimizations on them. + DisableAllLoopOptsOnLoop(*L); + } + + if (!PostLoop.Blocks.empty()) { + auto *L = createClonedLoopStructure( + &OriginalLoop, OriginalLoop.getParentLoop(), PostLoop.Map); + formLCSSARecursively(*L, DT, &LI, &SE); + simplifyLoop(L, &DT, &LI, &SE, nullptr, true); + // Post loops are slow paths, we do not need to perform any loop + // optimizations on them. 
+ DisableAllLoopOptsOnLoop(*L); + } + + formLCSSARecursively(OriginalLoop, DT, &LI, &SE); + simplifyLoop(&OriginalLoop, &DT, &LI, &SE, nullptr, true); return true; } @@ -1439,8 +1525,9 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { if (!SafeIterRange.hasValue()) return false; - LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LS, - SE, SafeIterRange.getValue()); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LPM, + LS, SE, DT, SafeIterRange.getValue()); bool Changed = LC.run(); if (Changed) { diff --git a/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp index 55ffc23..1870c3d 100644 --- a/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -134,7 +134,7 @@ bool JumpThreading::runOnFunction(Function &F) { } PreservedAnalyses JumpThreadingPass::run(Function &F, - AnalysisManager<Function> &AM) { + FunctionAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &LVI = AM.getResult<LazyValueAnalysis>(F); @@ -951,12 +951,17 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { // Scan a few instructions up from the load, to see if it is obviously live at // the entry to its block. BasicBlock::iterator BBIt(LI); - + bool IsLoadCSE; if (Value *AvailableVal = - FindAvailableLoadedValue(LI, LoadBB, BBIt, DefMaxInstsToScan)) { + FindAvailableLoadedValue(LI, LoadBB, BBIt, DefMaxInstsToScan, nullptr, &IsLoadCSE)) { // If the value of the load is locally available within the block, just use // it. This frequently occurs for reg2mem'd allocas. + if (IsLoadCSE) { + LoadInst *NLI = cast<LoadInst>(AvailableVal); + combineMetadataForCSE(NLI, LI); + }; + // If the returned value is the load itself, replace with an undef. This can // only happen in dead loops. if (AvailableVal == LI) AvailableVal = UndefValue::get(LI->getType()); @@ -983,6 +988,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { typedef SmallVector<std::pair<BasicBlock*, Value*>, 8> AvailablePredsTy; AvailablePredsTy AvailablePreds; BasicBlock *OneUnavailablePred = nullptr; + SmallVector<LoadInst*, 8> CSELoads; // If we got here, the loaded value is transparent through to the start of the // block. Check to see if it is available in any of the predecessor blocks. @@ -993,17 +999,17 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { // Scan the predecessor to see if the value is available in the pred. BBIt = PredBB->end(); - AAMDNodes ThisAATags; Value *PredAvailable = FindAvailableLoadedValue(LI, PredBB, BBIt, DefMaxInstsToScan, - nullptr, &ThisAATags); + nullptr, + &IsLoadCSE); if (!PredAvailable) { OneUnavailablePred = PredBB; continue; } - // If AA tags disagree or are not present, forget about them. - if (AATags != ThisAATags) AATags = AAMDNodes(); + if (IsLoadCSE) + CSELoads.push_back(cast<LoadInst>(PredAvailable)); // If so, this load is partially redundant. Remember this info so that we // can create a PHI node. 
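For context on the SimplifyPartiallyRedundantLoad hunks above and below, here is a minimal source-level sketch (illustrative only, not part of the patch; the function and variable names are invented) of the situation the pass handles, and why the loads it CSE's together now go through combineMetadataForCSE:
// Before jump threading: the reload of *p in the join block is partially
// redundant -- its value is already available on the 'c == true' path.
int use_twice(bool c, int *p) {
  int v = 0;
  if (c)
    v = *p;      // load L1
  // no stores to *p on either path
  return v + *p; // load L2: redundant with L1 whenever c is true
}
// The pass replaces L2 with a PHI of L1's value and a load inserted in the
// other predecessor; metadata on loads that are CSE'd together (e.g. !tbaa,
// !range) has to be merged conservatively, which is what the calls to
// combineMetadataForCSE in this change do.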
@@ -1101,6 +1107,10 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { PN->addIncoming(PredV, I->first); } + for (LoadInst *PredLI : CSELoads) { + combineMetadataForCSE(PredLI, LI); + } + LI->replaceAllUsesWith(PN); LI->eraseFromParent(); @@ -1157,8 +1167,7 @@ FindMostPopularDest(BasicBlock *BB, for (unsigned i = 0; ; ++i) { assert(i != TI->getNumSuccessors() && "Didn't find any successor!"); - if (std::find(SamePopularity.begin(), SamePopularity.end(), - TI->getSuccessor(i)) == SamePopularity.end()) + if (!is_contained(SamePopularity, TI->getSuccessor(i))) continue; MostPopularDest = TI->getSuccessor(i); @@ -1594,7 +1603,7 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, } /// Create a new basic block that will be the predecessor of BB and successor of -/// all blocks in Preds. When profile data is availble, update the frequency of +/// all blocks in Preds. When profile data is available, update the frequency of /// this new block. BasicBlock *JumpThreadingPass::SplitBlockPreds(BasicBlock *BB, ArrayRef<BasicBlock *> Preds, @@ -1615,6 +1624,23 @@ BasicBlock *JumpThreadingPass::SplitBlockPreds(BasicBlock *BB, return PredBB; } +bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) { + const TerminatorInst *TI = BB->getTerminator(); + assert(TI->getNumSuccessors() > 1 && "not a split"); + + MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); + if (!WeightsNode) + return false; + + MDString *MDName = cast<MDString>(WeightsNode->getOperand(0)); + if (MDName->getString() != "branch_weights") + return false; + + // Ensure there are weights for all of the successors. Note that the first + // operand to the metadata node is a name, not a weight. + return WeightsNode->getNumOperands() == TI->getNumSuccessors() + 1; +} + /// Update the block frequency of BB and branch weight and the metadata on the /// edge BB->SuccBB. This is done by scaling the weight of BB->SuccBB by 1 - /// Freq(PredBB->BB) / Freq(BB->SuccBB). @@ -1665,7 +1691,41 @@ void JumpThreadingPass::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, for (int I = 0, E = BBSuccProbs.size(); I < E; I++) BPI->setEdgeProbability(BB, I, BBSuccProbs[I]); - if (BBSuccProbs.size() >= 2) { + // Update the profile metadata as well. + // + // Don't do this if the profile of the transformed blocks was statically + // estimated. (This could occur despite the function having an entry + // frequency in completely cold parts of the CFG.) + // + // In this case we don't want to suggest to subsequent passes that the + // calculated weights are fully consistent. Consider this graph: + // + // check_1 + // 50% / | + // eq_1 | 50% + // \ | + // check_2 + // 50% / | + // eq_2 | 50% + // \ | + // check_3 + // 50% / | + // eq_3 | 50% + // \ | + // + // Assuming the blocks check_* all compare the same value against 1, 2 and 3, + // the overall probabilities are inconsistent; the total probability that the + // value is either 1, 2 or 3 is 150%. + // + // As a consequence if we thread eq_1 -> check_2 to check_3, check_2->check_3 + // becomes 0%. This is even worse if the edge whose probability becomes 0% is + // the loop exit edge. Then based solely on static estimation we would assume + // the loop was extremely hot. + // + // FIXME this locally as well so that BPI and BFI are consistent as well. We + // shouldn't make edges extremely likely or unlikely based solely on static + // estimation. 
+ if (BBSuccProbs.size() >= 2 && doesBlockHaveProfileData(BB)) { SmallVector<uint32_t, 4> Weights; for (auto Prob : BBSuccProbs) Weights.push_back(Prob.getNumerator()); diff --git a/contrib/llvm/lib/Transforms/Scalar/LICM.cpp b/contrib/llvm/lib/Transforms/Scalar/LICM.cpp index cdd17fc..f51d11c 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LICM.cpp @@ -41,8 +41,8 @@ #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -61,6 +61,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" @@ -84,14 +85,17 @@ static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI); static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo); static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo); + const LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE); static bool sink(Instruction &I, const LoopInfo *LI, const DominatorTree *DT, const Loop *CurLoop, AliasSetTracker *CurAST, - const LoopSafetyInfo *SafetyInfo); -static bool isSafeToExecuteUnconditionally(const Instruction &Inst, + const LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE); +static bool isSafeToExecuteUnconditionally(Instruction &Inst, const DominatorTree *DT, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE, const Instruction *CtxI = nullptr); static bool pointerInvalidatedByLoop(Value *V, uint64_t Size, const AAMDNodes &AAInfo, @@ -100,15 +104,12 @@ static Instruction * CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, const LoopSafetyInfo *SafetyInfo); -static bool canSinkOrHoistInst(Instruction &I, AliasAnalysis *AA, - DominatorTree *DT, TargetLibraryInfo *TLI, - Loop *CurLoop, AliasSetTracker *CurAST, - LoopSafetyInfo *SafetyInfo); namespace { struct LoopInvariantCodeMotion { bool runOnLoop(Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, - TargetLibraryInfo *TLI, ScalarEvolution *SE, bool DeleteAST); + TargetLibraryInfo *TLI, ScalarEvolution *SE, + OptimizationRemarkEmitter *ORE, bool DeleteAST); DenseMap<Loop *, AliasSetTracker *> &getLoopToAliasSetMap() { return LoopToAliasSetMap; @@ -128,16 +129,27 @@ struct LegacyLICMPass : public LoopPass { } bool runOnLoop(Loop *L, LPPassManager &LPM) override { - if (skipLoop(L)) + if (skipLoop(L)) { + // If we have run LICM on a previous loop but now we are skipping + // (because we've hit the opt-bisect limit), we need to clear the + // loop alias information. + for (auto &LTAS : LICM.getLoopToAliasSetMap()) + delete LTAS.second; + LICM.getLoopToAliasSetMap().clear(); return false; + } auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); + // For the old PM, we can't use OptimizationRemarkEmitter as an analysis + // pass.
Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). + OptimizationRemarkEmitter ORE(L->getHeader()->getParent()); return LICM.runOnLoop(L, &getAnalysis<AAResultsWrapperPass>().getAAResults(), &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), - SE ? &SE->getSE() : nullptr, false); + SE ? &SE->getSE() : nullptr, &ORE, false); } /// This transformation requires natural loop information & requires that @@ -173,21 +185,20 @@ private: }; } -PreservedAnalyses LICMPass::run(Loop &L, AnalysisManager<Loop> &AM) { +PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, LPMUpdater &) { const auto &FAM = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); Function *F = L.getHeader()->getParent(); - auto *AA = FAM.getCachedResult<AAManager>(*F); - auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); - auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); - auto *TLI = FAM.getCachedResult<TargetLibraryAnalysis>(*F); - auto *SE = FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); - assert((AA && LI && DT && TLI && SE) && "Analyses for LICM not available"); + auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); + // FIXME: This should probably be optional rather than required. + if (!ORE) + report_fatal_error("LICM: OptimizationRemarkEmitterAnalysis not " + "cached at a higher level"); LoopInvariantCodeMotion LICM; - - if (!LICM.runOnLoop(&L, AA, LI, DT, TLI, SE, true)) + if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.SE, ORE, true)) return PreservedAnalyses::all(); // FIXME: There is no setPreservesCFG in the new PM. When that becomes @@ -214,7 +225,9 @@ Pass *llvm::createLICMPass() { return new LegacyLICMPass(); } bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, - ScalarEvolution *SE, bool DeleteAST) { + ScalarEvolution *SE, + OptimizationRemarkEmitter *ORE, + bool DeleteAST) { bool Changed = false; assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); @@ -240,31 +253,54 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AliasAnalysis *AA, // if (L->hasDedicatedExits()) Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, L, - CurAST, &SafetyInfo); + CurAST, &SafetyInfo, ORE); if (Preheader) Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, L, - CurAST, &SafetyInfo); + CurAST, &SafetyInfo, ORE); // Now that all loop invariants have been removed from the loop, promote any // memory references to scalars that we can. - if (!DisablePromotion && (Preheader || L->hasDedicatedExits())) { + // Don't sink stores from loops without dedicated block exits. Exits + // containing indirect branches are not transformed by loop simplify, + // make sure we catch that. An additional load may be generated in the + // preheader for SSA updater, so also avoid sinking when no preheader + // is available. + if (!DisablePromotion && Preheader && L->hasDedicatedExits()) { + // Figure out the loop exits and their insertion points SmallVector<BasicBlock *, 8> ExitBlocks; - SmallVector<Instruction *, 8> InsertPts; - PredIteratorCache PIC; - - // Loop over all of the alias sets in the tracker object. 
- for (AliasSet &AS : *CurAST) - Changed |= promoteLoopAccessesToScalars( - AS, ExitBlocks, InsertPts, PIC, LI, DT, TLI, L, CurAST, &SafetyInfo); - - // Once we have promoted values across the loop body we have to recursively - // reform LCSSA as any nested loop may now have values defined within the - // loop used in the outer loop. - // FIXME: This is really heavy handed. It would be a bit better to use an - // SSAUpdater strategy during promotion that was LCSSA aware and reformed - // it as it went. - if (Changed) { - formLCSSARecursively(*L, *DT, LI, SE); + L->getUniqueExitBlocks(ExitBlocks); + + // We can't insert into a catchswitch. + bool HasCatchSwitch = llvm::any_of(ExitBlocks, [](BasicBlock *Exit) { + return isa<CatchSwitchInst>(Exit->getTerminator()); + }); + + if (!HasCatchSwitch) { + SmallVector<Instruction *, 8> InsertPts; + InsertPts.reserve(ExitBlocks.size()); + for (BasicBlock *ExitBlock : ExitBlocks) + InsertPts.push_back(&*ExitBlock->getFirstInsertionPt()); + + PredIteratorCache PIC; + + bool Promoted = false; + + // Loop over all of the alias sets in the tracker object. + for (AliasSet &AS : *CurAST) + Promoted |= + promoteLoopAccessesToScalars(AS, ExitBlocks, InsertPts, PIC, LI, DT, + TLI, L, CurAST, &SafetyInfo, ORE); + + // Once we have promoted values across the loop body we have to + // recursively reform LCSSA as any nested loop may now have values defined + // within the loop used in the outer loop. + // FIXME: This is really heavy handed. It would be a bit better to use an + // SSAUpdater strategy during promotion that was LCSSA aware and reformed + // it as it went. + if (Promoted) + formLCSSARecursively(*L, *DT, LI, SE); + + Changed |= Promoted; } } @@ -294,7 +330,8 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AliasAnalysis *AA, /// bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, - AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo) { + AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE) { // Verify inputs. assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && @@ -310,7 +347,8 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, bool Changed = false; const std::vector<DomTreeNode *> &Children = N->getChildren(); for (DomTreeNode *Child : Children) - Changed |= sinkRegion(Child, AA, LI, DT, TLI, CurLoop, CurAST, SafetyInfo); + Changed |= + sinkRegion(Child, AA, LI, DT, TLI, CurLoop, CurAST, SafetyInfo, ORE); // Only need to process the contents of this block if it is not part of a // subloop (which would already have been processed). @@ -337,9 +375,9 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // operands of the instruction are loop invariant. // if (isNotUsedInLoop(I, CurLoop, SafetyInfo) && - canSinkOrHoistInst(I, AA, DT, TLI, CurLoop, CurAST, SafetyInfo)) { + canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE)) { ++II; - Changed |= sink(I, LI, DT, CurLoop, CurAST, SafetyInfo); + Changed |= sink(I, LI, DT, CurLoop, CurAST, SafetyInfo, ORE); } } return Changed; @@ -352,7 +390,8 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, /// bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, - AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo) { + AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE) { // Verify inputs. 
assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr && @@ -382,6 +421,7 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, CurAST->deleteValue(&I); I.eraseFromParent(); } + Changed = true; continue; } @@ -390,16 +430,17 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // is safe to hoist the instruction. // if (CurLoop->hasLoopInvariantOperands(&I) && - canSinkOrHoistInst(I, AA, DT, TLI, CurLoop, CurAST, SafetyInfo) && + canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE) && isSafeToExecuteUnconditionally( - I, DT, CurLoop, SafetyInfo, + I, DT, CurLoop, SafetyInfo, ORE, CurLoop->getLoopPreheader()->getTerminator())) - Changed |= hoist(I, DT, CurLoop, SafetyInfo); + Changed |= hoist(I, DT, CurLoop, SafetyInfo, ORE); } const std::vector<DomTreeNode *> &Children = N->getChildren(); for (DomTreeNode *Child : Children) - Changed |= hoistRegion(Child, AA, LI, DT, TLI, CurLoop, CurAST, SafetyInfo); + Changed |= + hoistRegion(Child, AA, LI, DT, TLI, CurLoop, CurAST, SafetyInfo, ORE); return Changed; } @@ -436,12 +477,10 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) { SafetyInfo->BlockColors = colorEHFunclets(*Fn); } -/// canSinkOrHoistInst - Return true if the hoister and sinker can handle this -/// instruction. -/// -bool canSinkOrHoistInst(Instruction &I, AliasAnalysis *AA, DominatorTree *DT, - TargetLibraryInfo *TLI, Loop *CurLoop, - AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo) { +bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, + Loop *CurLoop, AliasSetTracker *CurAST, + LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE) { // Loads have extra constraints we have to verify before we can hoist them. if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { if (!LI->isUnordered()) @@ -462,7 +501,17 @@ bool canSinkOrHoistInst(Instruction &I, AliasAnalysis *AA, DominatorTree *DT, AAMDNodes AAInfo; LI->getAAMetadata(AAInfo); - return !pointerInvalidatedByLoop(LI->getOperand(0), Size, AAInfo, CurAST); + bool Invalidated = + pointerInvalidatedByLoop(LI->getOperand(0), Size, AAInfo, CurAST); + // Check loop-invariant address because this may also be a sinkable load + // whose address is not necessarily loop-invariant. + if (ORE && Invalidated && CurLoop->isLoopInvariant(LI->getPointerOperand())) + ORE->emit(OptimizationRemarkMissed( + DEBUG_TYPE, "LoadWithLoopInvariantAddressInvalidated", LI) + << "failed to move load with loop-invariant address " + "because the loop may invalidate its value"); + + return !Invalidated; } else if (CallInst *CI = dyn_cast<CallInst>(&I)) { // Don't sink or hoist dbg info; it's legal, but not useful. if (isa<DbgInfoIntrinsic>(I)) @@ -515,6 +564,11 @@ bool canSinkOrHoistInst(Instruction &I, AliasAnalysis *AA, DominatorTree *DT, !isa<InsertValueInst>(I)) return false; + // SafetyInfo is nullptr if we are checking for sinking from preheader to + // loop body. It will be always safe as there is no speculative execution. + if (!SafetyInfo) + return true; + // TODO: Plumb the context instruction through to make hoisting and sinking // more powerful. Hoisting of loads already works due to the special casing // above. 
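As a concrete illustration of the LoadWithLoopInvariantAddressInvalidated remark added in the canSinkOrHoistInst hunk above (a sketch only, not taken from the patch; identifiers are invented): a load can have a loop-invariant address and still be impossible to hoist when a store in the loop body may alias it.
// The address of *bound never changes inside the loop, but the stores through
// 'out' may alias it (there is no restrict/noalias information), so LICM must
// reload it every iteration and reports the missed optimization instead of
// hoisting the load.
void scale(int *out, const int *in, const int *bound, int n) {
  for (int i = 0; i < n; ++i)
    out[i] = in[i] * *bound;
}
// Building with optimization remarks enabled (e.g. -Rpass-missed=licm in
// clang) should surface the remark at the reload of *bound.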
@@ -651,8 +705,11 @@ CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, /// static bool sink(Instruction &I, const LoopInfo *LI, const DominatorTree *DT, const Loop *CurLoop, AliasSetTracker *CurAST, - const LoopSafetyInfo *SafetyInfo) { + const LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE) { DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); + ORE->emit(OptimizationRemark(DEBUG_TYPE, "InstSunk", &I) + << "sinking " << ore::NV("Inst", &I)); bool Changed = false; if (isa<LoadInst>(I)) ++NumMovedLoads; @@ -719,10 +776,13 @@ static bool sink(Instruction &I, const LoopInfo *LI, const DominatorTree *DT, /// is safe to hoist, this instruction is called to do the dirty work. /// static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo) { + const LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE) { auto *Preheader = CurLoop->getLoopPreheader(); DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I << "\n"); + ORE->emit(OptimizationRemark(DEBUG_TYPE, "Hoisted", &I) + << "hoisting " << ore::NV("Inst", &I)); // Metadata can be dependent on conditions we are hoisting above. // Conservatively strip all metadata on the instruction unless we were @@ -738,6 +798,14 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, // Move the new node to the Preheader, before its terminator. I.moveBefore(Preheader->getTerminator()); + // Do not retain debug locations when we are moving instructions to different + // basic blocks, because we want to avoid jumpy line tables. Calls, however, + // need to retain their debug locs because they may be inlined. + // FIXME: How do we retain source locations without causing poor debugging + // behavior? + if (!isa<CallInst>(I)) + I.setDebugLoc(DebugLoc()); + if (isa<LoadInst>(I)) ++NumMovedLoads; else if (isa<CallInst>(I)) @@ -749,15 +817,28 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, /// Only sink or hoist an instruction if it is not a trapping instruction, /// or if the instruction is known not to trap when moved to the preheader, /// or if it is a trapping instruction and is guaranteed to execute.
-static bool isSafeToExecuteUnconditionally(const Instruction &Inst, +static bool isSafeToExecuteUnconditionally(Instruction &Inst, const DominatorTree *DT, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE, const Instruction *CtxI) { if (isSafeToSpeculativelyExecute(&Inst, CtxI, DT)) return true; - return isGuaranteedToExecute(Inst, DT, CurLoop, SafetyInfo); + bool GuaranteedToExecute = + isGuaranteedToExecute(Inst, DT, CurLoop, SafetyInfo); + + if (!GuaranteedToExecute) { + auto *LI = dyn_cast<LoadInst>(&Inst); + if (LI && CurLoop->isLoopInvariant(LI->getPointerOperand())) + ORE->emit(OptimizationRemarkMissed( + DEBUG_TYPE, "LoadWithLoopInvariantAddressCondExecuted", LI) + << "failed to hoist load with loop-invariant address " + "because load is conditionally executed"); + } + + return GuaranteedToExecute; } namespace { @@ -845,7 +926,8 @@ bool llvm::promoteLoopAccessesToScalars( AliasSet &AS, SmallVectorImpl<BasicBlock *> &ExitBlocks, SmallVectorImpl<Instruction *> &InsertPts, PredIteratorCache &PIC, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - Loop *CurLoop, AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo) { + Loop *CurLoop, AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE) { // Verify inputs. assert(LI != nullptr && DT != nullptr && CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr && @@ -876,23 +958,33 @@ bool llvm::promoteLoopAccessesToScalars( // is not safe, because *P may only be valid to access if 'c' is true. // // The safety property divides into two parts: - // 1) The memory may not be dereferenceable on entry to the loop. In this + // p1) The memory may not be dereferenceable on entry to the loop. In this // case, we can't insert the required load in the preheader. - // 2) The memory model does not allow us to insert a store along any dynamic + // p2) The memory model does not allow us to insert a store along any dynamic // path which did not originally have one. // - // It is safe to promote P if all uses are direct load/stores and if at - // least one is guaranteed to be executed. - bool GuaranteedToExecute = false; - - // It is also safe to promote P if we can prove that speculating a load into - // the preheader is safe (i.e. proving dereferenceability on all - // paths through the loop), and that the memory can be proven thread local - // (so that the memory model requirement doesn't apply.) We first establish - // the former, and then run a capture analysis below to establish the later. - // We can use any access within the alias set to prove dereferenceability + // If at least one store is guaranteed to execute, both properties are + // satisfied, and promotion is legal. + // + // This, however, is not a necessary condition. Even if no store/load is + // guaranteed to execute, we can still establish these properties. + // We can establish (p1) by proving that hoisting the load into the preheader + // is safe (i.e. proving dereferenceability on all paths through the loop). We + // can use any access within the alias set to prove dereferenceability, // since they're all must alias. - bool CanSpeculateLoad = false; + // + // There are two ways establish (p2): + // a) Prove the location is thread-local. In this case the memory model + // requirement does not apply, and stores are safe to insert. + // b) Prove a store dominates every exit block. 
In this case, if an exit + // blocks is reached, the original dynamic path would have taken us through + // the store, so inserting a store into the exit block is safe. Note that this + // is different from the store being guaranteed to execute. For instance, + // if an exception is thrown on the first iteration of the loop, the original + // store is never executed, but the exit blocks are not executed either. + + bool DereferenceableInPH = false; + bool SafeToInsertStore = false; SmallVector<Instruction *, 64> LoopUses; SmallPtrSet<Value *, 4> PointerMustAliases; @@ -901,15 +993,6 @@ bool llvm::promoteLoopAccessesToScalars( // us to prove better alignment. unsigned Alignment = 1; AAMDNodes AATags; - bool HasDedicatedExits = CurLoop->hasDedicatedExits(); - - // Don't sink stores from loops without dedicated block exits. Exits - // containing indirect branches are not transformed by loop simplify, - // make sure we catch that. An additional load may be generated in the - // preheader for SSA updater, so also avoid sinking when no preheader - // is available. - if (!HasDedicatedExits || !Preheader) - return false; const DataLayout &MDL = Preheader->getModule()->getDataLayout(); @@ -926,7 +1009,6 @@ bool llvm::promoteLoopAccessesToScalars( // Check that all of the pointers in the alias set have the same type. We // cannot (yet) promote a memory location that is loaded and stored in // different sizes. While we are at it, collect alignment and AA info. - bool Changed = false; for (const auto &ASI : AS) { Value *ASIV = ASI.getValue(); PointerMustAliases.insert(ASIV); @@ -935,7 +1017,7 @@ bool llvm::promoteLoopAccessesToScalars( // cannot (yet) promote a memory location that is loaded and stored in // different sizes. if (SomePtr->getType() != ASIV->getType()) - return Changed; + return false; for (User *U : ASIV->users()) { // Ignore instructions that are outside the loop. @@ -945,14 +1027,14 @@ bool llvm::promoteLoopAccessesToScalars( // If there is an non-load/store instruction in the loop, we can't promote // it. - if (const LoadInst *Load = dyn_cast<LoadInst>(UI)) { + if (LoadInst *Load = dyn_cast<LoadInst>(UI)) { assert(!Load->isVolatile() && "AST broken"); if (!Load->isSimple()) - return Changed; + return false; - if (!GuaranteedToExecute && !CanSpeculateLoad) - CanSpeculateLoad = isSafeToExecuteUnconditionally( - *Load, DT, CurLoop, SafetyInfo, Preheader->getTerminator()); + if (!DereferenceableInPH) + DereferenceableInPH = isSafeToExecuteUnconditionally( + *Load, DT, CurLoop, SafetyInfo, ORE, Preheader->getTerminator()); } else if (const StoreInst *Store = dyn_cast<StoreInst>(UI)) { // Stores *of* the pointer are not interesting, only stores *to* the // pointer. @@ -960,35 +1042,47 @@ bool llvm::promoteLoopAccessesToScalars( continue; assert(!Store->isVolatile() && "AST broken"); if (!Store->isSimple()) - return Changed; - - // Note that we only check GuaranteedToExecute inside the store case - // so that we do not introduce stores where they did not exist before - // (which would break the LLVM concurrency model). + return false; - // If the alignment of this instruction allows us to specify a more - // restrictive (and performant) alignment and if we are sure this - // instruction will be executed, update the alignment. - // Larger is better, with the exception of 0 being the best alignment. + // If the store is guaranteed to execute, both properties are satisfied. 
+ // We may want to check if a store is guaranteed to execute even if we + // already know that promotion is safe, since it may have higher + // alignment than any other guaranteed stores, in which case we can + // raise the alignment on the promoted store. unsigned InstAlignment = Store->getAlignment(); - if ((InstAlignment > Alignment || InstAlignment == 0) && - Alignment != 0) { + if (!InstAlignment) + InstAlignment = + MDL.getABITypeAlignment(Store->getValueOperand()->getType()); + + if (!DereferenceableInPH || !SafeToInsertStore || + (InstAlignment > Alignment)) { if (isGuaranteedToExecute(*UI, DT, CurLoop, SafetyInfo)) { - GuaranteedToExecute = true; - Alignment = InstAlignment; + DereferenceableInPH = true; + SafeToInsertStore = true; + Alignment = std::max(Alignment, InstAlignment); } - } else if (!GuaranteedToExecute) { - GuaranteedToExecute = - isGuaranteedToExecute(*UI, DT, CurLoop, SafetyInfo); } - if (!GuaranteedToExecute && !CanSpeculateLoad) { - CanSpeculateLoad = isDereferenceableAndAlignedPointer( + // If a store dominates all exit blocks, it is safe to sink. + // As explained above, if an exit block was executed, a dominating + // store must have been been executed at least once, so we are not + // introducing stores on paths that did not have them. + // Note that this only looks at explicit exit blocks. If we ever + // start sinking stores into unwind edges (see above), this will break. + if (!SafeToInsertStore) + SafeToInsertStore = llvm::all_of(ExitBlocks, [&](BasicBlock *Exit) { + return DT->dominates(Store->getParent(), Exit); + }); + + // If the store is not guaranteed to execute, we may still get + // deref info through it. + if (!DereferenceableInPH) { + DereferenceableInPH = isDereferenceableAndAlignedPointer( Store->getPointerOperand(), Store->getAlignment(), MDL, Preheader->getTerminator(), DT); } } else - return Changed; // Not a load or store. + return false; // Not a load or store. // Merge the AA tags. if (LoopUses.empty()) { @@ -1002,38 +1096,32 @@ bool llvm::promoteLoopAccessesToScalars( } } - // Check legality per comment above. Otherwise, we can't promote. - bool PromotionIsLegal = GuaranteedToExecute; - if (!PromotionIsLegal && CanSpeculateLoad) { - // If this is a thread local location, then we can insert stores along - // paths which originally didn't have them without violating the memory - // model. - Value *Object = GetUnderlyingObject(SomePtr, MDL); - PromotionIsLegal = - isAllocLikeFn(Object, TLI) && !PointerMayBeCaptured(Object, true, true); - } - if (!PromotionIsLegal) - return Changed; - // Figure out the loop exits and their insertion points, if this is the - // first promotion. - if (ExitBlocks.empty()) { - CurLoop->getUniqueExitBlocks(ExitBlocks); - InsertPts.clear(); - InsertPts.reserve(ExitBlocks.size()); - for (BasicBlock *ExitBlock : ExitBlocks) - InsertPts.push_back(&*ExitBlock->getFirstInsertionPt()); + // If we couldn't prove we can hoist the load, bail. + if (!DereferenceableInPH) + return false; + + // We know we can hoist the load, but don't have a guaranteed store. + // Check whether the location is thread-local. If it is, then we can insert + // stores along paths which originally didn't have them without violating the + // memory model. + if (!SafeToInsertStore) { + Value *Object = GetUnderlyingObject(SomePtr, MDL); + SafeToInsertStore = + (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) && + !PointerMayBeCaptured(Object, true, true); } - // Can't insert into a catchswitch. 
- for (BasicBlock *ExitBlock : ExitBlocks) - if (isa<CatchSwitchInst>(ExitBlock->getTerminator())) - return Changed; + // If we've still failed to prove we can sink the store, give up. + if (!SafeToInsertStore) + return false; // Otherwise, this is safe to promote, lets do it! DEBUG(dbgs() << "LICM: Promoting value stored to in loop: " << *SomePtr << '\n'); - Changed = true; + ORE->emit( + OptimizationRemark(DEBUG_TYPE, "PromoteLoopAccessesToScalar", LoopUses[0]) + << "Moving accesses to memory location out of the loop"); ++NumPromoted; // Grab a debug location for the inserted loads/stores; given that the @@ -1066,13 +1154,13 @@ bool llvm::promoteLoopAccessesToScalars( if (PreheaderLoad->use_empty()) PreheaderLoad->eraseFromParent(); - return Changed; + return true; } /// Returns an owning pointer to an alias set which incorporates aliasing info /// from L and all subloops of L. -/// FIXME: In new pass manager, there is no helper functions to handle loop -/// analysis such as cloneBasicBlockAnalysis. So the AST needs to be recompute +/// FIXME: In new pass manager, there is no helper function to handle loop +/// analysis such as cloneBasicBlockAnalysis, so the AST needs to be recomputed /// from scratch for every loop. Hook up with the helper functions when /// available in the new pass manager to avoid redundant computation. AliasSetTracker * @@ -1108,10 +1196,7 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, auto mergeLoop = [&](Loop *L) { // Loop over the body of this loop, looking for calls, invokes, and stores. - // Because subloops have already been incorporated into AST, we skip blocks - // in subloops. for (BasicBlock *BB : L->blocks()) - if (LI->getLoopFor(BB) == L) // Ignore blocks in subloops. CurAST->add(*BB); // Incorporate the specified basic block }; diff --git a/contrib/llvm/lib/Transforms/Scalar/LoadCombine.cpp b/contrib/llvm/lib/Transforms/Scalar/LoadCombine.cpp index dfe51a4..389f1c5 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoadCombine.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoadCombine.cpp @@ -44,9 +44,6 @@ struct PointerOffsetPair { }; struct LoadPOPPair { - LoadPOPPair() = default; - LoadPOPPair(LoadInst *L, PointerOffsetPair P, unsigned O) - : Load(L), POP(P), InsertOrder(O) {} LoadInst *Load; PointerOffsetPair POP; /// \brief The new load needs to be created before the first load in IR order. @@ -71,7 +68,7 @@ public: AU.addPreserved<GlobalsAAWrapperPass>(); } - const char *getPassName() const override { return LDCOMBINE_NAME; } + StringRef getPassName() const override { return LDCOMBINE_NAME; } static char ID; typedef IRBuilder<TargetFolder> BuilderTy; @@ -264,7 +261,7 @@ bool LoadCombine::runOnBasicBlock(BasicBlock &BB) { auto POP = getPointerOffsetPair(*LI); if (!POP.Pointer) continue; - LoadMap[POP.Pointer].push_back(LoadPOPPair(LI, POP, Index++)); + LoadMap[POP.Pointer].push_back({LI, std::move(POP), Index++}); AST.add(LI); } if (combineLoads(LoadMap)) diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp new file mode 100644 index 0000000..a64c991 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp @@ -0,0 +1,25 @@ +//===- LoopAccessAnalysisPrinter.cpp - Loop Access Analysis Printer --------==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopAccessAnalysisPrinter.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" +using namespace llvm; + +#define DEBUG_TYPE "loop-accesses" + +PreservedAnalyses +LoopAccessInfoPrinterPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, LPMUpdater &) { + Function &F = *L.getHeader()->getParent(); + auto &LAI = AM.getResult<LoopAccessAnalysis>(L, AR); + OS << "Loop access info in function '" << F.getName() << "':\n"; + OS.indent(2) << L.getHeader()->getName() << ":\n"; + LAI.print(OS, 4); + return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp index 66b59d2..d09af32 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -11,14 +11,16 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/LoopDataPrefetch.h" + #define DEBUG_TYPE "loop-data-prefetch" -#include "llvm/Transforms/Scalar.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" @@ -26,13 +28,13 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" @@ -59,77 +61,89 @@ static cl::opt<unsigned> MaxPrefetchIterationsAhead( STATISTIC(NumPrefetches, "Number of prefetches inserted"); -namespace llvm { - void initializeLoopDataPrefetchPass(PassRegistry&); -} - namespace { - class LoopDataPrefetch : public FunctionPass { - public: - static char ID; // Pass ID, replacement for typeid - LoopDataPrefetch() : FunctionPass(ID) { - initializeLoopDataPrefetchPass(*PassRegistry::getPassRegistry()); - } +/// Loop prefetch implementation class. +class LoopDataPrefetch { +public: + LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE, + const TargetTransformInfo *TTI, + OptimizationRemarkEmitter *ORE) + : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {} - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AssumptionCacheTracker>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - // FIXME: For some reason, preserving SE here breaks LSR (even if - // this pass changes nothing). 
- // AU.addPreserved<ScalarEvolutionWrapperPass>(); - AU.addRequired<TargetTransformInfoWrapperPass>(); - } + bool run(); - bool runOnFunction(Function &F) override; +private: + bool runOnLoop(Loop *L); - private: - bool runOnLoop(Loop *L); + /// \brief Check if the the stride of the accesses is large enough to + /// warrant a prefetch. + bool isStrideLargeEnough(const SCEVAddRecExpr *AR); - /// \brief Check if the the stride of the accesses is large enough to - /// warrant a prefetch. - bool isStrideLargeEnough(const SCEVAddRecExpr *AR); + unsigned getMinPrefetchStride() { + if (MinPrefetchStride.getNumOccurrences() > 0) + return MinPrefetchStride; + return TTI->getMinPrefetchStride(); + } - unsigned getMinPrefetchStride() { - if (MinPrefetchStride.getNumOccurrences() > 0) - return MinPrefetchStride; - return TTI->getMinPrefetchStride(); - } + unsigned getPrefetchDistance() { + if (PrefetchDistance.getNumOccurrences() > 0) + return PrefetchDistance; + return TTI->getPrefetchDistance(); + } - unsigned getPrefetchDistance() { - if (PrefetchDistance.getNumOccurrences() > 0) - return PrefetchDistance; - return TTI->getPrefetchDistance(); - } + unsigned getMaxPrefetchIterationsAhead() { + if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0) + return MaxPrefetchIterationsAhead; + return TTI->getMaxPrefetchIterationsAhead(); + } - unsigned getMaxPrefetchIterationsAhead() { - if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0) - return MaxPrefetchIterationsAhead; - return TTI->getMaxPrefetchIterationsAhead(); - } + AssumptionCache *AC; + LoopInfo *LI; + ScalarEvolution *SE; + const TargetTransformInfo *TTI; + OptimizationRemarkEmitter *ORE; +}; + +/// Legacy class for inserting loop data prefetches. +class LoopDataPrefetchLegacyPass : public FunctionPass { +public: + static char ID; // Pass ID, replacement for typeid + LoopDataPrefetchLegacyPass() : FunctionPass(ID) { + initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + // FIXME: For some reason, preserving SE here breaks LSR (even if + // this pass changes nothing). 
+ // AU.addPreserved<ScalarEvolutionWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + } - AssumptionCache *AC; - LoopInfo *LI; - ScalarEvolution *SE; - const TargetTransformInfo *TTI; - const DataLayout *DL; + bool runOnFunction(Function &F) override; }; } -char LoopDataPrefetch::ID = 0; -INITIALIZE_PASS_BEGIN(LoopDataPrefetch, "loop-data-prefetch", +char LoopDataPrefetchLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch", "Loop Data Prefetch", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_END(LoopDataPrefetch, "loop-data-prefetch", +INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch", "Loop Data Prefetch", false, false) -FunctionPass *llvm::createLoopDataPrefetchPass() { return new LoopDataPrefetch(); } +FunctionPass *llvm::createLoopDataPrefetchPass() { + return new LoopDataPrefetchLegacyPass(); +} bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) { unsigned TargetMinStride = getMinPrefetchStride(); @@ -147,16 +161,46 @@ bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) { return TargetMinStride <= AbsStride; } -bool LoopDataPrefetch::runOnFunction(Function &F) { +PreservedAnalyses LoopDataPrefetchPass::run(Function &F, + FunctionAnalysisManager &AM) { + LoopInfo *LI = &AM.getResult<LoopAnalysis>(F); + ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); + AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); + OptimizationRemarkEmitter *ORE = + &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F); + + LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); + bool Changed = LDP.run(); + + if (Changed) { + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); + return PA; + } + + return PreservedAnalyses::all(); +} + +bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) { if (skipFunction(F)) return false; - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - DL = &F.getParent()->getDataLayout(); - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + AssumptionCache *AC = + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + OptimizationRemarkEmitter *ORE = + &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + const TargetTransformInfo *TTI = + &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + + LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); + return LDP.run(); +} +bool LoopDataPrefetch::run() { // If PrefetchDistance is not set, don't run the pass. This gives an // opportunity for targets to run this pass for selected subtargets only // (whose TTI sets PrefetchDistance). 
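The getMinPrefetchStride/getPrefetchDistance/getMaxPrefetchIterationsAhead accessors in the class above all follow the same pattern: an explicit command-line value overrides the target's TTI default, and an unset (zero) distance keeps the pass from doing any work. A standalone sketch of that pattern follows (types and names are invented here, not LLVM API):
#include <optional>

// Stand-in for the TTI queries used by the pass; a target that does not set a
// prefetch distance effectively opts out.
struct TargetHooks {
  unsigned prefetchDistance() const { return 0; }
};

// The command-line value, when present, wins; otherwise fall back to the
// target default.
unsigned effectivePrefetchDistance(std::optional<unsigned> FromFlag,
                                   const TargetHooks &TTI) {
  return FromFlag ? *FromFlag : TTI.prefetchDistance();
}

// Mirrors the behavior described at the top of LoopDataPrefetch::run(): a
// distance of zero means the pass has nothing to do for this target.
bool shouldRunPrefetcher(std::optional<unsigned> FromFlag,
                         const TargetHooks &TTI) {
  return effectivePrefetchDistance(FromFlag, TTI) != 0;
}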
@@ -185,19 +229,16 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { // Calculate the number of iterations ahead to prefetch CodeMetrics Metrics; - for (Loop::block_iterator I = L->block_begin(), IE = L->block_end(); - I != IE; ++I) { - + for (const auto BB : L->blocks()) { // If the loop already has prefetches, then assume that the user knows // what they are doing and don't add any more. - for (BasicBlock::iterator J = (*I)->begin(), JE = (*I)->end(); - J != JE; ++J) - if (CallInst *CI = dyn_cast<CallInst>(J)) + for (auto &I : *BB) + if (CallInst *CI = dyn_cast<CallInst>(&I)) if (Function *F = CI->getCalledFunction()) if (F->getIntrinsicID() == Intrinsic::prefetch) return MadeChange; - Metrics.analyzeBasicBlock(*I, *TTI, EphValues); + Metrics.analyzeBasicBlock(BB, *TTI, EphValues); } unsigned LoopSize = Metrics.NumInsts; if (!LoopSize) @@ -210,23 +251,20 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { if (ItersAhead > getMaxPrefetchIterationsAhead()) return MadeChange; - Function *F = L->getHeader()->getParent(); DEBUG(dbgs() << "Prefetching " << ItersAhead << " iterations ahead (loop size: " << LoopSize << ") in " - << F->getName() << ": " << *L); + << L->getHeader()->getParent()->getName() << ": " << *L); SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads; - for (Loop::block_iterator I = L->block_begin(), IE = L->block_end(); - I != IE; ++I) { - for (BasicBlock::iterator J = (*I)->begin(), JE = (*I)->end(); - J != JE; ++J) { + for (const auto BB : L->blocks()) { + for (auto &I : *BB) { Value *PtrValue; Instruction *MemI; - if (LoadInst *LMemI = dyn_cast<LoadInst>(J)) { + if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) { MemI = LMemI; PtrValue = LMemI->getPointerOperand(); - } else if (StoreInst *SMemI = dyn_cast<StoreInst>(J)) { + } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) { if (!PrefetchWrites) continue; MemI = SMemI; PtrValue = SMemI->getPointerOperand(); @@ -275,13 +313,13 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec)); - Type *I8Ptr = Type::getInt8PtrTy((*I)->getContext(), PtrAddrSpace); - SCEVExpander SCEVE(*SE, J->getModule()->getDataLayout(), "prefaddr"); + Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace); + SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr"); Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI); IRBuilder<> Builder(MemI); - Module *M = (*I)->getParent()->getParent(); - Type *I32 = Type::getInt32Ty((*I)->getContext()); + Module *M = BB->getParent()->getParent(); + Type *I32 = Type::getInt32Ty(BB->getContext()); Value *PrefetchFunc = Intrinsic::getDeclaration(M, Intrinsic::prefetch); Builder.CreateCall( PrefetchFunc, @@ -291,9 +329,8 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { ++NumPrefetches; DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV << "\n"); - emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, - MemI->getDebugLoc(), "prefetched memory access"); - + ORE->emit(OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI) + << "prefetched memory access"); MadeChange = true; } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp index 19b2f89..cca75a3 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -19,9 +19,9 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopPassManager.h" #include 
"llvm/IR/Dominators.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -215,15 +215,10 @@ bool LoopDeletionPass::runImpl(Loop *L, DominatorTree &DT, ScalarEvolution &SE, return Changed; } -PreservedAnalyses LoopDeletionPass::run(Loop &L, AnalysisManager<Loop> &AM) { - auto &FAM = AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); - Function *F = L.getHeader()->getParent(); - - auto &DT = *FAM.getCachedResult<DominatorTreeAnalysis>(*F); - auto &SE = *FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); - auto &LI = *FAM.getCachedResult<LoopAnalysis>(*F); - - bool Changed = runImpl(&L, DT, SE, LI); +PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + bool Changed = runImpl(&L, AR.DT, AR.SE, AR.LI); if (!Changed) return PreservedAnalyses::all(); diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 7eca28e..19716b2 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -28,15 +28,16 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -72,11 +73,10 @@ static cl::opt<unsigned> PragmaDistributeSCEVCheckThreshold( "The maximum number of SCEV checks allowed for Loop " "Distribution for loop marked with #pragma loop distribute(enable)")); -// Note that the initial value for this depends on whether the pass is invoked -// directly or from the optimization pipeline. static cl::opt<bool> EnableLoopDistribute( "enable-loop-distribute", cl::Hidden, - cl::desc("Enable the new, experimental LoopDistribution Pass")); + cl::desc("Enable the new, experimental LoopDistribution Pass"), + cl::init(false)); STATISTIC(NumLoopsDistributed, "Number of loops distributed"); @@ -605,11 +605,13 @@ public: DEBUG(dbgs() << "\nLDist: In \"" << L->getHeader()->getParent()->getName() << "\" checking " << *L << "\n"); - BasicBlock *PH = L->getLoopPreheader(); - if (!PH) - return fail("no preheader"); if (!L->getExitBlock()) - return fail("multiple exit blocks"); + return fail("MultipleExitBlocks", "multiple exit blocks"); + if (!L->isLoopSimplifyForm()) + return fail("NotLoopSimplifyForm", + "loop is not in loop-simplify form"); + + BasicBlock *PH = L->getLoopPreheader(); // LAA will check that we only have a single exiting block. LAI = &GetLAA(*L); @@ -617,11 +619,12 @@ public: // Currently, we only distribute to isolate the part of the loop with // dependence cycles to enable partial vectorization. 
if (LAI->canVectorizeMemory()) - return fail("memory operations are safe for vectorization"); + return fail("MemOpsCanBeVectorized", + "memory operations are safe for vectorization"); auto *Dependences = LAI->getDepChecker().getDependences(); if (!Dependences || Dependences->empty()) - return fail("no unsafe dependences to isolate"); + return fail("NoUnsafeDeps", "no unsafe dependences to isolate"); InstPartitionContainer Partitions(L, LI, DT); @@ -674,14 +677,16 @@ public: DEBUG(dbgs() << "Seeded partitions:\n" << Partitions); if (Partitions.getSize() < 2) - return fail("cannot isolate unsafe dependencies"); + return fail("CantIsolateUnsafeDeps", + "cannot isolate unsafe dependencies"); // Run the merge heuristics: Merge non-cyclic adjacent partitions since we // should be able to vectorize these together. Partitions.mergeBeforePopulating(); DEBUG(dbgs() << "\nMerged partitions:\n" << Partitions); if (Partitions.getSize() < 2) - return fail("cannot isolate unsafe dependencies"); + return fail("CantIsolateUnsafeDeps", + "cannot isolate unsafe dependencies"); // Now, populate the partitions with non-memory operations. Partitions.populateUsedSet(); @@ -693,7 +698,8 @@ public: DEBUG(dbgs() << "\nPartitions merged to ensure unique loads:\n" << Partitions); if (Partitions.getSize() < 2) - return fail("cannot isolate unsafe dependencies"); + return fail("CantIsolateUnsafeDeps", + "cannot isolate unsafe dependencies"); } // Don't distribute the loop if we need too many SCEV run-time checks. @@ -701,7 +707,8 @@ public: if (Pred.getComplexity() > (IsForced.getValueOr(false) ? PragmaDistributeSCEVCheckThreshold : DistributeSCEVCheckThreshold)) - return fail("too many SCEV run-time checks needed.\n"); + return fail("TooManySCEVRuntimeChecks", + "too many SCEV run-time checks needed.\n"); DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n"); // We're done forming the partitions set up the reverse mapping from @@ -742,36 +749,38 @@ public: DEBUG(Partitions.printBlocks()); if (LDistVerify) { - LI->verify(); + LI->verify(*DT); DT->verifyDomTree(); } ++NumLoopsDistributed; // Report the success. - emitOptimizationRemark(F->getContext(), LDIST_NAME, *F, L->getStartLoc(), - "distributed loop"); + ORE->emit(OptimizationRemark(LDIST_NAME, "Distribute", L->getStartLoc(), + L->getHeader()) + << "distributed loop"); return true; } /// \brief Provide diagnostics then \return with false. - bool fail(llvm::StringRef Message) { + bool fail(StringRef RemarkName, StringRef Message) { LLVMContext &Ctx = F->getContext(); bool Forced = isForced().getValueOr(false); DEBUG(dbgs() << "Skipping; " << Message << "\n"); // With Rpass-missed report that distribution failed. - ORE->emitOptimizationRemarkMissed( - LDIST_NAME, L, - "loop not distributed: use -Rpass-analysis=loop-distribute for more " - "info"); + ORE->emit( + OptimizationRemarkMissed(LDIST_NAME, "NotDistributed", L->getStartLoc(), + L->getHeader()) + << "loop not distributed: use -Rpass-analysis=loop-distribute for more " + "info"); // With Rpass-analysis report why. This is on by default if distribution // was requested explicitly. - emitOptimizationRemarkAnalysis( - Ctx, Forced ? DiagnosticInfoOptimizationRemarkAnalysis::AlwaysPrint - : LDIST_NAME, - *F, L->getStartLoc(), Twine("loop not distributed: ") + Message); + ORE->emit(OptimizationRemarkAnalysis( + Forced ? 
OptimizationRemarkAnalysis::AlwaysPrint : LDIST_NAME, + RemarkName, L->getStartLoc(), L->getHeader()) + << "loop not distributed: " << Message); // Also issue a warning if distribution was requested explicitly but it // failed. @@ -865,8 +874,7 @@ private: /// Shared implementation between new and old PMs. static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE, - std::function<const LoopAccessInfo &(Loop &)> &GetLAA, - bool ProcessAllLoops) { + std::function<const LoopAccessInfo &(Loop &)> &GetLAA) { // Build up a worklist of inner-loops to vectorize. This is necessary as the // act of distributing a loop creates new loops and can invalidate iterators // across the loops. @@ -885,7 +893,7 @@ static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, // If distribution was forced for the specific loop to be // enabled/disabled, follow that. Otherwise use the global flag. - if (LDL.isForced().getValueOr(ProcessAllLoops)) + if (LDL.isForced().getValueOr(EnableLoopDistribute)) Changed |= LDL.processLoop(GetLAA); } @@ -896,15 +904,8 @@ static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, /// \brief The pass class. class LoopDistributeLegacy : public FunctionPass { public: - /// \p ProcessAllLoopsByDefault specifies whether loop distribution should be - /// performed by default. Pass -enable-loop-distribute={0,1} overrides this - /// default. We use this to keep LoopDistribution off by default when invoked - /// from the optimization pipeline but on when invoked explicitly from opt. - LoopDistributeLegacy(bool ProcessAllLoopsByDefault = true) - : FunctionPass(ID), ProcessAllLoops(ProcessAllLoopsByDefault) { + LoopDistributeLegacy() : FunctionPass(ID) { // The default is set by the caller. - if (EnableLoopDistribute.getNumOccurrences() > 0) - ProcessAllLoops = EnableLoopDistribute; initializeLoopDistributeLegacyPass(*PassRegistry::getPassRegistry()); } @@ -920,7 +921,7 @@ public: std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; - return runImpl(F, LI, DT, SE, ORE, GetLAA, ProcessAllLoops); + return runImpl(F, LI, DT, SE, ORE, GetLAA); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -931,48 +932,46 @@ public: AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); } static char ID; - -private: - /// \brief Whether distribution should be on in this function. The per-loop - /// pragma can override this. - bool ProcessAllLoops; }; } // anonymous namespace PreservedAnalyses LoopDistributePass::run(Function &F, FunctionAnalysisManager &AM) { - // FIXME: This does not currently match the behavior from the old PM. - // ProcessAllLoops with the old PM defaults to true when invoked from opt and - // false when invoked from the optimization pipeline. - bool ProcessAllLoops = false; - if (EnableLoopDistribute.getNumOccurrences() > 0) - ProcessAllLoops = EnableLoopDistribute; - auto &LI = AM.getResult<LoopAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + // We don't directly need these analyses but they're required for loop + // analyses so provide them below. 
+ auto &AA = AM.getResult<AAManager>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - return LAM.getResult<LoopAccessAnalysis>(L); + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI}; + return LAM.getResult<LoopAccessAnalysis>(L, AR); }; - bool Changed = runImpl(F, &LI, &DT, &SE, &ORE, GetLAA, ProcessAllLoops); + bool Changed = runImpl(F, &LI, &DT, &SE, &ORE, GetLAA); if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<LoopAnalysis>(); PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<GlobalsAA>(); return PA; } char LoopDistributeLegacy::ID; -static const char ldist_name[] = "Loop Distribition"; +static const char ldist_name[] = "Loop Distribution"; INITIALIZE_PASS_BEGIN(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, false) @@ -984,7 +983,5 @@ INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, false) namespace llvm { -FunctionPass *createLoopDistributePass(bool ProcessAllLoopsByDefault) { - return new LoopDistributeLegacy(ProcessAllLoopsByDefault); -} +FunctionPass *createLoopDistributePass() { return new LoopDistributeLegacy(); } } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 1468676..5fec51c 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -11,6 +11,12 @@ // non-loop form. In cases that this kicks in, it can be a significant // performance win. // +// If compiling for code size we avoid idiom recognition if the resulting +// code could be larger than the code for the original loop. One way this could +// happen is if the loop is not removable after idiom recognition due to the +// presence of non-idiom instructions. The initial implementation of the +// heuristics applies to idioms in multi-block loops. 
+// //===----------------------------------------------------------------------===// // // TODO List: @@ -40,7 +46,6 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -55,6 +60,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -65,6 +71,12 @@ using namespace llvm; STATISTIC(NumMemSet, "Number of memset's formed from loop stores"); STATISTIC(NumMemCpy, "Number of memcpy's formed from loop load+stores"); +static cl::opt<bool> UseLIRCodeSizeHeurs( + "use-lir-code-size-heurs", + cl::desc("Use loop idiom recognition code size heuristics when compiling" + "with -Os/-Oz"), + cl::init(true), cl::Hidden); + namespace { class LoopIdiomRecognize { @@ -76,6 +88,7 @@ class LoopIdiomRecognize { TargetLibraryInfo *TLI; const TargetTransformInfo *TTI; const DataLayout *DL; + bool ApplyCodeSizeHeuristics; public: explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT, @@ -117,8 +130,10 @@ private: Instruction *TheStore, SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev, const SCEV *BECount, - bool NegStride); + bool NegStride, bool IsLoopMemset = false); bool processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *BECount); + bool avoidLIRForMultiBlockLoop(bool IsMemset = false, + bool IsLoopMemset = false); /// @} /// \name Noncountable Loop Idiom Handling @@ -171,24 +186,12 @@ public: }; } // End anonymous namespace. -PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, - AnalysisManager<Loop> &AM) { - const auto &FAM = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); - Function *F = L.getHeader()->getParent(); - - // Use getCachedResult because Loop pass cannot trigger a function analysis. - auto *AA = FAM.getCachedResult<AAManager>(*F); - auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); - auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); - auto *SE = FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); - auto *TLI = FAM.getCachedResult<TargetLibraryAnalysis>(*F); - const auto *TTI = FAM.getCachedResult<TargetIRAnalysis>(*F); +PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { const auto *DL = &L.getHeader()->getModule()->getDataLayout(); - assert((AA && DT && LI && SE && TLI && TTI && DL) && - "Analyses for Loop Idiom Recognition not available"); - LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL); + LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, DL); if (!LIR.runOnLoop(&L)) return PreservedAnalyses::all(); @@ -229,6 +232,10 @@ bool LoopIdiomRecognize::runOnLoop(Loop *L) { if (Name == "memset" || Name == "memcpy") return false; + // Determine if code size heuristics need to be applied. 
+ ApplyCodeSizeHeuristics = + L->getHeader()->getParent()->optForSize() && UseLIRCodeSizeHeurs; + HasMemset = TLI->has(LibFunc::memset); HasMemsetPattern = TLI->has(LibFunc::memset_pattern16); HasMemcpy = TLI->has(LibFunc::memcpy); @@ -689,7 +696,7 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, bool NegStride = SizeInBytes == -Stride; return processLoopStridedStore(Pointer, (unsigned)SizeInBytes, MSI->getAlignment(), SplatValue, MSI, MSIs, Ev, - BECount, NegStride); + BECount, NegStride, /*IsLoopMemset=*/true); } /// mayLoopAccessLocation - Return true if the specified loop might access the @@ -745,7 +752,7 @@ bool LoopIdiomRecognize::processLoopStridedStore( Value *DestPtr, unsigned StoreSize, unsigned StoreAlignment, Value *StoredVal, Instruction *TheStore, SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev, - const SCEV *BECount, bool NegStride) { + const SCEV *BECount, bool NegStride, bool IsLoopMemset) { Value *SplatValue = isBytewiseValue(StoredVal); Constant *PatternValue = nullptr; @@ -786,6 +793,9 @@ bool LoopIdiomRecognize::processLoopStridedStore( return false; } + if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset)) + return false; + // Okay, everything looks good, insert the memset. // The # stored bytes is (BECount+1)*Size. Expand the trip count out to @@ -917,6 +927,9 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, return false; } + if (avoidLIRForMultiBlockLoop()) + return false; + // Okay, everything is safe, we can transform this! // The # stored bytes is (BECount+1)*Size. Expand the trip count out to @@ -948,6 +961,23 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, return true; } +// When compiling for codesize we avoid idiom recognition for a multi-block loop +// unless it is a loop_memset idiom or a memset/memcpy idiom in a nested loop. +// +bool LoopIdiomRecognize::avoidLIRForMultiBlockLoop(bool IsMemset, + bool IsLoopMemset) { + if (ApplyCodeSizeHeuristics && CurLoop->getNumBlocks() > 1) { + if (!CurLoop->getParentLoop() && (!IsMemset || !IsLoopMemset)) { + DEBUG(dbgs() << " " << CurLoop->getHeader()->getParent()->getName() + << " : LIR " << (IsMemset ? "Memset" : "Memcpy") + << " avoided: multi-block top-level loop\n"); + return true; + } + } + + return false; +} + bool LoopIdiomRecognize::runOnNoncountableLoop() { return recognizePopcount(); } @@ -955,7 +985,7 @@ bool LoopIdiomRecognize::runOnNoncountableLoop() { /// Check if the given conditional branch is based on the comparison between /// a variable and zero, and if the variable is non-zero, the control yields to /// the loop entry. If the branch matches the behavior, the variable involved -/// in the comparion is returned. This function will be called to see if the +/// in the comparison is returned. This function will be called to see if the /// precondition and postcondition of the loop are in desirable form. static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry) { if (!BI || !BI->isConditional()) @@ -1139,9 +1169,7 @@ bool LoopIdiomRecognize::recognizePopcount() { // It should have a preheader containing nothing but an unconditional branch. 
BasicBlock *PH = CurLoop->getLoopPreheader(); - if (!PH) - return false; - if (&PH->front() != PH->getTerminator()) + if (!PH || &PH->front() != PH->getTerminator()) return false; auto *EntryBI = dyn_cast<BranchInst>(PH->getTerminator()); if (!EntryBI || EntryBI->isConditional()) diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index 629cb87..69102d1 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -18,7 +18,6 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" @@ -26,6 +25,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -183,20 +183,10 @@ public: }; } -PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, - AnalysisManager<Loop> &AM) { - const auto &FAM = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); - Function *F = L.getHeader()->getParent(); - - // Use getCachedResult because Loop pass cannot trigger a function analysis. - auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); - auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); - auto *AC = FAM.getCachedResult<AssumptionAnalysis>(*F); - const auto *TLI = FAM.getCachedResult<TargetLibraryAnalysis>(*F); - assert((LI && AC && TLI) && "Analyses for Loop Inst Simplify not available"); - - if (!SimplifyLoopInst(&L, DT, LI, AC, TLI)) +PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + if (!SimplifyLoopInst(&L, &AR.DT, &AR.LI, &AR.AC, &AR.TLI)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 9241ec3..e9f84ed 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -44,6 +44,10 @@ using namespace llvm; #define DEBUG_TYPE "loop-interchange" +static cl::opt<int> LoopInterchangeCostThreshold( + "loop-interchange-threshold", cl::init(0), cl::Hidden, + cl::desc("Interchange if you gain more than this number")); + namespace { typedef SmallVector<Loop *, 8> LoopVector; @@ -75,30 +79,23 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, typedef SmallVector<Value *, 16> ValueVector; ValueVector MemInstr; - if (Level > MaxLoopNestDepth) { - DEBUG(dbgs() << "Cannot handle loops of depth greater than " - << MaxLoopNestDepth << "\n"); - return false; - } - // For each block. for (Loop::block_iterator BB = L->block_begin(), BE = L->block_end(); BB != BE; ++BB) { // Scan the BB and collect legal loads and stores. 
for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); I != E; ++I) { - Instruction *Ins = dyn_cast<Instruction>(I); - if (!Ins) - return false; - LoadInst *Ld = dyn_cast<LoadInst>(I); - StoreInst *St = dyn_cast<StoreInst>(I); - if (!St && !Ld) - continue; - if (Ld && !Ld->isSimple()) - return false; - if (St && !St->isSimple()) + if (!isa<Instruction>(I)) return false; - MemInstr.push_back(&*I); + if (LoadInst *Ld = dyn_cast<LoadInst>(I)) { + if (!Ld->isSimple()) + return false; + MemInstr.push_back(&*I); + } else if (StoreInst *St = dyn_cast<StoreInst>(I)) { + if (!St->isSimple()) + return false; + MemInstr.push_back(&*I); + } } } @@ -110,66 +107,63 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, for (I = MemInstr.begin(), IE = MemInstr.end(); I != IE; ++I) { for (J = I, JE = MemInstr.end(); J != JE; ++J) { std::vector<char> Dep; - Instruction *Src = dyn_cast<Instruction>(*I); - Instruction *Des = dyn_cast<Instruction>(*J); - if (Src == Des) + Instruction *Src = cast<Instruction>(*I); + Instruction *Dst = cast<Instruction>(*J); + if (Src == Dst) continue; - if (isa<LoadInst>(Src) && isa<LoadInst>(Des)) + // Ignore Input dependencies. + if (isa<LoadInst>(Src) && isa<LoadInst>(Dst)) continue; - if (auto D = DI->depends(Src, Des, true)) { - DEBUG(dbgs() << "Found Dependency between Src=" << Src << " Des=" << Des - << "\n"); - if (D->isFlow()) { - // TODO: Handle Flow dependence.Check if it is sufficient to populate - // the Dependence Matrix with the direction reversed. - DEBUG(dbgs() << "Flow dependence not handled"); - return false; - } - if (D->isAnti()) { - DEBUG(dbgs() << "Found Anti dependence \n"); - unsigned Levels = D->getLevels(); - char Direction; - for (unsigned II = 1; II <= Levels; ++II) { - const SCEV *Distance = D->getDistance(II); - const SCEVConstant *SCEVConst = - dyn_cast_or_null<SCEVConstant>(Distance); - if (SCEVConst) { - const ConstantInt *CI = SCEVConst->getValue(); - if (CI->isNegative()) - Direction = '<'; - else if (CI->isZero()) - Direction = '='; - else - Direction = '>'; - Dep.push_back(Direction); - } else if (D->isScalar(II)) { - Direction = 'S'; - Dep.push_back(Direction); - } else { - unsigned Dir = D->getDirection(II); - if (Dir == Dependence::DVEntry::LT || - Dir == Dependence::DVEntry::LE) - Direction = '<'; - else if (Dir == Dependence::DVEntry::GT || - Dir == Dependence::DVEntry::GE) - Direction = '>'; - else if (Dir == Dependence::DVEntry::EQ) - Direction = '='; - else - Direction = '*'; - Dep.push_back(Direction); - } - } - while (Dep.size() != Level) { - Dep.push_back('I'); + // Track Output, Flow, and Anti dependencies. + if (auto D = DI->depends(Src, Dst, true)) { + assert(D->isOrdered() && "Expected an output, flow or anti dep."); + DEBUG(StringRef DepType = + D->isFlow() ? "flow" : D->isAnti() ? 
"anti" : "output"; + dbgs() << "Found " << DepType + << " dependency between Src and Dst\n" + << " Src:" << *Src << "\n Dst:" << *Dst << '\n'); + unsigned Levels = D->getLevels(); + char Direction; + for (unsigned II = 1; II <= Levels; ++II) { + const SCEV *Distance = D->getDistance(II); + const SCEVConstant *SCEVConst = + dyn_cast_or_null<SCEVConstant>(Distance); + if (SCEVConst) { + const ConstantInt *CI = SCEVConst->getValue(); + if (CI->isNegative()) + Direction = '<'; + else if (CI->isZero()) + Direction = '='; + else + Direction = '>'; + Dep.push_back(Direction); + } else if (D->isScalar(II)) { + Direction = 'S'; + Dep.push_back(Direction); + } else { + unsigned Dir = D->getDirection(II); + if (Dir == Dependence::DVEntry::LT || + Dir == Dependence::DVEntry::LE) + Direction = '<'; + else if (Dir == Dependence::DVEntry::GT || + Dir == Dependence::DVEntry::GE) + Direction = '>'; + else if (Dir == Dependence::DVEntry::EQ) + Direction = '='; + else + Direction = '*'; + Dep.push_back(Direction); } + } + while (Dep.size() != Level) { + Dep.push_back('I'); + } - DepMatrix.push_back(Dep); - if (DepMatrix.size() > MaxMemInstrCount) { - DEBUG(dbgs() << "Cannot handle more than " << MaxMemInstrCount - << " dependencies inside loop\n"); - return false; - } + DepMatrix.push_back(Dep); + if (DepMatrix.size() > MaxMemInstrCount) { + DEBUG(dbgs() << "Cannot handle more than " << MaxMemInstrCount + << " dependencies inside loop\n"); + return false; } } } @@ -183,8 +177,8 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, // A loop is moved from index 'from' to an index 'to'. Update the Dependence // matrix by exchanging the two columns. -static void interChangeDepedencies(CharMatrix &DepMatrix, unsigned FromIndx, - unsigned ToIndx) { +static void interChangeDependencies(CharMatrix &DepMatrix, unsigned FromIndx, + unsigned ToIndx) { unsigned numRows = DepMatrix.size(); for (unsigned i = 0; i < numRows; ++i) { char TmpVal = DepMatrix[i][ToIndx]; @@ -211,7 +205,7 @@ static bool isOuterMostDepPositive(CharMatrix &DepMatrix, unsigned Row, static bool containsNoDependence(CharMatrix &DepMatrix, unsigned Row, unsigned Column) { for (unsigned i = 0; i < Column; ++i) { - if (DepMatrix[Row][i] != '=' || DepMatrix[Row][i] != 'S' || + if (DepMatrix[Row][i] != '=' && DepMatrix[Row][i] != 'S' && DepMatrix[Row][i] != 'I') return false; } @@ -255,9 +249,8 @@ static bool validDepInterchange(CharMatrix &DepMatrix, unsigned Row, // Checks if it is legal to interchange 2 loops. // [Theorem] A permutation of the loops in a perfect nest is legal if and only -// if -// the direction matrix, after the same permutation is applied to its columns, -// has no ">" direction as the leftmost non-"=" direction in any row. +// if the direction matrix, after the same permutation is applied to its +// columns, has no ">" direction as the leftmost non-"=" direction in any row. 
static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, unsigned InnerLoopId, unsigned OuterLoopId) { @@ -269,8 +262,7 @@ static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, char OuterDep = DepMatrix[Row][OuterLoopId]; if (InnerDep == '*' || OuterDep == '*') return false; - else if (!validDepInterchange(DepMatrix, Row, OuterLoopId, InnerDep, - OuterDep)) + if (!validDepInterchange(DepMatrix, Row, OuterLoopId, InnerDep, OuterDep)) return false; } return true; @@ -278,7 +270,9 @@ static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, static void populateWorklist(Loop &L, SmallVector<LoopVector, 8> &V) { - DEBUG(dbgs() << "Calling populateWorklist called\n"); + DEBUG(dbgs() << "Calling populateWorklist on Func: " + << L.getHeader()->getParent()->getName() << " Loop: %" + << L.getHeader()->getName() << '\n'); LoopVector LoopList; Loop *CurrentLoop = &L; const std::vector<Loop *> *Vec = &CurrentLoop->getSubLoops(); @@ -315,8 +309,7 @@ static PHINode *getInductionVariable(Loop *L, ScalarEvolution *SE) { if (!AddRec || !AddRec->isAffine()) continue; const SCEV *Step = AddRec->getStepRecurrence(*SE); - const SCEVConstant *C = dyn_cast<SCEVConstant>(Step); - if (!C) + if (!isa<SCEVConstant>(Step)) continue; // Found the induction variable. // FIXME: Handle loops with more than one induction variable. Note that, @@ -474,7 +467,7 @@ struct LoopInterchange : public FunctionPass { for (Loop *L : LoopList) { const SCEV *ExitCountOuter = SE->getBackedgeTakenCount(L); if (ExitCountOuter == SE->getCouldNotCompute()) { - DEBUG(dbgs() << "Couldn't compute Backedge count\n"); + DEBUG(dbgs() << "Couldn't compute backedge count\n"); return false; } if (L->getNumBackEdges() != 1) { @@ -482,7 +475,7 @@ struct LoopInterchange : public FunctionPass { return false; } if (!L->getExitingBlock()) { - DEBUG(dbgs() << "Loop Doesn't have unique exit block\n"); + DEBUG(dbgs() << "Loop doesn't have unique exit block\n"); return false; } } @@ -498,27 +491,32 @@ struct LoopInterchange : public FunctionPass { bool processLoopList(LoopVector LoopList, Function &F) { bool Changed = false; - CharMatrix DependencyMatrix; - if (LoopList.size() < 2) { + unsigned LoopNestDepth = LoopList.size(); + if (LoopNestDepth < 2) { DEBUG(dbgs() << "Loop doesn't contain minimum nesting level.\n"); return false; } + if (LoopNestDepth > MaxLoopNestDepth) { + DEBUG(dbgs() << "Cannot handle loops of depth greater than " + << MaxLoopNestDepth << "\n"); + return false; + } if (!isComputableLoopNest(LoopList)) { - DEBUG(dbgs() << "Not vaild loop candidate for interchange\n"); + DEBUG(dbgs() << "Not valid loop candidate for interchange\n"); return false; } - Loop *OuterMostLoop = *(LoopList.begin()); - DEBUG(dbgs() << "Processing LoopList of size = " << LoopList.size() - << "\n"); + DEBUG(dbgs() << "Processing LoopList of size = " << LoopNestDepth << "\n"); - if (!populateDependencyMatrix(DependencyMatrix, LoopList.size(), + CharMatrix DependencyMatrix; + Loop *OuterMostLoop = *(LoopList.begin()); + if (!populateDependencyMatrix(DependencyMatrix, LoopNestDepth, OuterMostLoop, DI)) { - DEBUG(dbgs() << "Populating Dependency matrix failed\n"); + DEBUG(dbgs() << "Populating dependency matrix failed\n"); return false; } #ifdef DUMP_DEP_MATRICIES - DEBUG(dbgs() << "Dependence before inter change \n"); + DEBUG(dbgs() << "Dependence before interchange\n"); printDepMatrix(DependencyMatrix); #endif @@ -556,10 +554,10 @@ struct LoopInterchange : public FunctionPass { std::swap(LoopList[i - 1], LoopList[i]); // Update the DependencyMatrix - 
interChangeDepedencies(DependencyMatrix, i, i - 1); + interChangeDependencies(DependencyMatrix, i, i - 1); DT->recalculate(F); #ifdef DUMP_DEP_MATRICIES - DEBUG(dbgs() << "Dependence after inter change \n"); + DEBUG(dbgs() << "Dependence after interchange\n"); printDepMatrix(DependencyMatrix); #endif Changed |= Interchanged; @@ -571,7 +569,7 @@ struct LoopInterchange : public FunctionPass { unsigned OuterLoopId, BasicBlock *LoopNestExit, std::vector<std::vector<char>> &DependencyMatrix) { - DEBUG(dbgs() << "Processing Innder Loop Id = " << InnerLoopId + DEBUG(dbgs() << "Processing Inner Loop Id = " << InnerLoopId << " and OuterLoopId = " << OuterLoopId << "\n"); Loop *InnerLoop = LoopList[InnerLoopId]; Loop *OuterLoop = LoopList[OuterLoopId]; @@ -585,7 +583,7 @@ struct LoopInterchange : public FunctionPass { DEBUG(dbgs() << "Loops are legal to interchange\n"); LoopInterchangeProfitability LIP(OuterLoop, InnerLoop, SE); if (!LIP.isProfitable(InnerLoopId, OuterLoopId, DependencyMatrix)) { - DEBUG(dbgs() << "Interchanging Loops not profitable\n"); + DEBUG(dbgs() << "Interchanging loops not profitable\n"); return false; } @@ -599,8 +597,8 @@ struct LoopInterchange : public FunctionPass { } // end of namespace bool LoopInterchangeLegality::areAllUsesReductions(Instruction *Ins, Loop *L) { - return !std::any_of(Ins->user_begin(), Ins->user_end(), [=](User *U) -> bool { - PHINode *UserIns = dyn_cast<PHINode>(U); + return none_of(Ins->users(), [=](User *U) -> bool { + auto *UserIns = dyn_cast<PHINode>(U); RecurrenceDescriptor RD; return !UserIns || !RecurrenceDescriptor::isReductionPHI(UserIns, L, RD); }); @@ -626,8 +624,7 @@ bool LoopInterchangeLegality::containsUnsafeInstructionsInLatch( // Stores corresponding to reductions are safe while concluding if tightly // nested. if (StoreInst *L = dyn_cast<StoreInst>(I)) { - PHINode *PHI = dyn_cast<PHINode>(L->getOperand(0)); - if (!PHI) + if (!isa<PHINode>(L->getOperand(0))) return true; } else if (I->mayHaveSideEffects() || I->mayReadFromMemory()) return true; @@ -640,30 +637,30 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) { BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); - DEBUG(dbgs() << "Checking if Loops are Tightly Nested\n"); + DEBUG(dbgs() << "Checking if loops are tightly nested\n"); // A perfectly nested loop will not have any branch in between the outer and // inner block i.e. outer header will branch to either inner preheader and // outerloop latch. - BranchInst *outerLoopHeaderBI = + BranchInst *OuterLoopHeaderBI = dyn_cast<BranchInst>(OuterLoopHeader->getTerminator()); - if (!outerLoopHeaderBI) + if (!OuterLoopHeaderBI) return false; - unsigned num = outerLoopHeaderBI->getNumSuccessors(); - for (unsigned i = 0; i < num; i++) { - if (outerLoopHeaderBI->getSuccessor(i) != InnerLoopPreHeader && - outerLoopHeaderBI->getSuccessor(i) != OuterLoopLatch) + + for (unsigned i = 0, e = OuterLoopHeaderBI->getNumSuccessors(); i < e; ++i) { + if (OuterLoopHeaderBI->getSuccessor(i) != InnerLoopPreHeader && + OuterLoopHeaderBI->getSuccessor(i) != OuterLoopLatch) return false; } - DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch \n"); + DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n"); // We do not have any basic block in between now make sure the outer header // and outer loop latch doesn't contain any unsafe instructions. 
if (containsUnsafeInstructionsInHeader(OuterLoopHeader) || containsUnsafeInstructionsInLatch(OuterLoopLatch)) return false; - DEBUG(dbgs() << "Loops are perfectly nested \n"); + DEBUG(dbgs() << "Loops are perfectly nested\n"); // We have a perfect loop nest. return true; } @@ -703,7 +700,7 @@ bool LoopInterchangeLegality::findInductionAndReductions( RecurrenceDescriptor RD; InductionDescriptor ID; PHINode *PHI = cast<PHINode>(I); - if (InductionDescriptor::isInductionPHI(PHI, SE, ID)) + if (InductionDescriptor::isInductionPHI(PHI, L, SE, ID)) Inductions.push_back(PHI); else if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD)) Reductions.push_back(PHI); @@ -852,8 +849,8 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, if (!isLegalToInterChangeLoops(DepMatrix, InnerLoopId, OuterLoopId)) { DEBUG(dbgs() << "Failed interchange InnerLoopId = " << InnerLoopId - << "and OuterLoopId = " << OuterLoopId - << "due to dependence\n"); + << " and OuterLoopId = " << OuterLoopId + << " due to dependence\n"); return false; } @@ -946,9 +943,9 @@ int LoopInterchangeProfitability::getInstrOrderCost() { return GoodOrder - BadOrder; } -static bool isProfitabileForVectorization(unsigned InnerLoopId, - unsigned OuterLoopId, - CharMatrix &DepMatrix) { +static bool isProfitableForVectorization(unsigned InnerLoopId, + unsigned OuterLoopId, + CharMatrix &DepMatrix) { // TODO: Improve this heuristic to catch more cases. // If the inner loop is loop independent or doesn't carry any dependency it is // profitable to move this to outer position. @@ -977,16 +974,15 @@ bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, // This is rough cost estimation algorithm. It counts the good and bad order // of induction variables in the instruction and allows reordering if number // of bad orders is more than good. - int Cost = 0; - Cost += getInstrOrderCost(); + int Cost = getInstrOrderCost(); DEBUG(dbgs() << "Cost = " << Cost << "\n"); - if (Cost < 0) + if (Cost < -LoopInterchangeCostThreshold) return true; // It is not profitable as per current cache profitability model. But check if // we can move this loop outside to improve parallelism. bool ImprovesPar = - isProfitabileForVectorization(InnerLoopId, OuterLoopId, DepMatrix); + isProfitableForVectorization(InnerLoopId, OuterLoopId, DepMatrix); return ImprovesPar; } @@ -1022,8 +1018,6 @@ void LoopInterchangeTransform::restructureLoops(Loop *InnerLoop, } bool LoopInterchangeTransform::transform() { - - DEBUG(dbgs() << "transform\n"); bool Transformed = false; Instruction *InnerIndexVar; @@ -1046,16 +1040,16 @@ bool LoopInterchangeTransform::transform() { // incremented/decremented. // TODO: This splitting logic may not work always. Fix this. splitInnerLoopLatch(InnerIndexVar); - DEBUG(dbgs() << "splitInnerLoopLatch Done\n"); + DEBUG(dbgs() << "splitInnerLoopLatch done\n"); // Splits the inner loops phi nodes out into a separate basic block. 
splitInnerLoopHeader(); - DEBUG(dbgs() << "splitInnerLoopHeader Done\n"); + DEBUG(dbgs() << "splitInnerLoopHeader done\n"); } Transformed |= adjustLoopLinks(); if (!Transformed) { - DEBUG(dbgs() << "adjustLoopLinks Failed\n"); + DEBUG(dbgs() << "adjustLoopLinks failed\n"); return false; } @@ -1099,7 +1093,7 @@ void LoopInterchangeTransform::splitInnerLoopHeader() { } DEBUG(dbgs() << "Output of splitInnerLoopHeader InnerLoopHeaderSucc & " - "InnerLoopHeader \n"); + "InnerLoopHeader\n"); } /// \brief Move all instructions except the terminator from FromBB right before diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index f29228c..8fb5801 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -20,17 +20,37 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/LoopVersioning.h" #include <forward_list> +#include <cassert> +#include <algorithm> +#include <set> +#include <tuple> +#include <utility> #define LLE_OPTION "loop-load-elim" #define DEBUG_TYPE LLE_OPTION @@ -47,7 +67,6 @@ static cl::opt<unsigned> LoadElimSCEVCheckThreshold( cl::desc("The maximum number of SCEV checks allowed for Loop " "Load Elimination")); - STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE"); namespace { @@ -113,10 +132,9 @@ bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, DominatorTree *DT) { SmallVector<BasicBlock *, 8> Latches; L->getLoopLatches(Latches); - return std::all_of(Latches.begin(), Latches.end(), - [&](const BasicBlock *Latch) { - return DT->dominates(StoreBlock, Latch); - }); + return llvm::all_of(Latches, [&](const BasicBlock *Latch) { + return DT->dominates(StoreBlock, Latch); + }); } /// \brief Return true if the load is not executed on all paths in the loop. @@ -348,7 +366,7 @@ public: // Collect the pointers of the candidate loads. // FIXME: SmallSet does not work with std::inserter. 
std::set<Value *> CandLoadPtrs; - std::transform(Candidates.begin(), Candidates.end(), + transform(Candidates, std::inserter(CandLoadPtrs, CandLoadPtrs.begin()), std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr)); @@ -397,7 +415,9 @@ public: Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(), PH->getTerminator()); Value *Initial = - new LoadInst(InitialPtr, "load_initial", PH->getTerminator()); + new LoadInst(InitialPtr, "load_initial", /* isVolatile */ false, + Cand.Load->getAlignment(), PH->getTerminator()); + PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded", &L->getHeader()->front()); PHI->addIncoming(Initial, PH); @@ -499,6 +519,11 @@ public: return false; } + if (!L->isLoopSimplifyForm()) { + DEBUG(dbgs() << "Loop is not is loop-simplify form"); + return false; + } + // Point of no-return, start the transformation. First, version the loop // if necessary. @@ -581,11 +606,13 @@ public: AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); } static char ID; }; -} + +} // end anonymous namespace char LoopLoadElimination::ID; static const char LLE_name[] = "Loop Load Elimination"; @@ -599,7 +626,9 @@ INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_END(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) namespace llvm { + FunctionPass *createLoopLoadEliminationPass() { return new LoopLoadElimination(); } -} + +} // end namespace llvm diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp new file mode 100644 index 0000000..028f4bb --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -0,0 +1,85 @@ +//===- LoopPassManager.cpp - Loop pass management -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Analysis/LoopInfo.h" + +using namespace llvm; + +// Explicit template instantiations and specialization defininitions for core +// template typedefs. +namespace llvm { +template class PassManager<Loop, LoopAnalysisManager, + LoopStandardAnalysisResults &, LPMUpdater &>; + +/// Explicitly specialize the pass manager's run method to handle loop nest +/// structure updates. +template <> +PreservedAnalyses +PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &, + LPMUpdater &>::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, LPMUpdater &U) { + PreservedAnalyses PA = PreservedAnalyses::all(); + + if (DebugLogging) + dbgs() << "Starting Loop pass manager run.\n"; + + for (auto &Pass : Passes) { + if (DebugLogging) + dbgs() << "Running pass: " << Pass->name() << " on " << L; + + PreservedAnalyses PassPA = Pass->run(L, AM, AR, U); + + // If the loop was deleted, abort the run and return to the outer walk. + if (U.skipCurrentLoop()) { + PA.intersect(std::move(PassPA)); + break; + } + + // Update the analysis manager as each pass runs and potentially + // invalidates analyses. + AM.invalidate(L, PassPA); + + // Finally, we intersect the final preserved analyses to compute the + // aggregate preserved set for this pass manager. 
+ PA.intersect(std::move(PassPA)); + + // FIXME: Historically, the pass managers all called the LLVM context's + // yield function here. We don't have a generic way to acquire the + // context and it isn't yet clear what the right pattern is for yielding + // in the new pass manager so it is currently omitted. + // ...getContext().yield(); + } + + // Invalidation for the current loop should be handled above, and other loop + // analysis results shouldn't be impacted by runs over this loop. Therefore, + // the remaining analysis results in the AnalysisManager are preserved. We + // mark this with a set so that we don't need to inspect each one + // individually. + // FIXME: This isn't correct! This loop and all nested loops' analyses should + // be preserved, but unrolling should invalidate the parent loop's analyses. + PA.preserveSet<AllAnalysesOn<Loop>>(); + + if (DebugLogging) + dbgs() << "Finished Loop pass manager run.\n"; + + return PA; +} +} + +PrintLoopPass::PrintLoopPass() : OS(dbgs()) {} +PrintLoopPass::PrintLoopPass(raw_ostream &OS, const std::string &Banner) + : OS(OS), Banner(Banner) {} + +PreservedAnalyses PrintLoopPass::run(Loop &L, LoopAnalysisManager &, + LoopStandardAnalysisResults &, + LPMUpdater &) { + printLoop(L, OS, Banner); + return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp index d2f1b66..86058fe 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -371,11 +371,12 @@ namespace { protected: typedef MapVector<Instruction*, BitVector> UsesTy; - bool findRootsRecursive(Instruction *IVU, + void findRootsRecursive(Instruction *IVU, SmallInstructionSet SubsumedInsts); bool findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts); bool collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots); + bool validateRootSet(DAGRootSet &DRS); bool collectUsedInstructions(SmallInstructionSet &PossibleRedSet); void collectInLoopUserSet(const SmallInstructionVector &Roots, @@ -739,11 +740,11 @@ void LoopReroll::DAGRootTracker::collectInLoopUserSet( collectInLoopUserSet(Root, Exclude, Final, Users); } -static bool isSimpleLoadStore(Instruction *I) { +static bool isUnorderedLoadStore(Instruction *I) { if (LoadInst *LI = dyn_cast<LoadInst>(I)) - return LI->isSimple(); + return LI->isUnordered(); if (StoreInst *SI = dyn_cast<StoreInst>(I)) - return SI->isSimple(); + return SI->isUnordered(); if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) return !MI->isVolatile(); return false; @@ -827,7 +828,8 @@ collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) { Roots[V] = cast<Instruction>(I); } - if (Roots.empty()) + // Make sure we have at least two roots. + if (Roots.empty() || (Roots.size() == 1 && BaseUsers.empty())) return false; // If we found non-loop-inc, non-root users of Base, assume they are @@ -861,40 +863,61 @@ collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) { return true; } -bool LoopReroll::DAGRootTracker:: +void LoopReroll::DAGRootTracker:: findRootsRecursive(Instruction *I, SmallInstructionSet SubsumedInsts) { // Does the user look like it could be part of a root set? // All its users must be simple arithmetic ops. 
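As a worked illustration of what a root set is and of the consecutiveness condition D = d * N that the patch moves into validateRootSet below — a hypothetical example, not taken from this patch:

// Hypothetical input of the kind LoopReroll targets: a loop manually
// unrolled by 3. The induction increment of 3 is the step D of the
// BaseInst's add-recurrence, the x[i + 1] and x[i + 2] accesses are the
// Roots, so N = 3 and d = 1, and the consecutiveness condition D == d * N
// (3 == 1 * 3) holds; the body can be rerolled to a single statement with
// a ++i induction.
void scale_unrolled(float *x, int n) {
  for (int i = 0; i < n; i += 3) {
    x[i]     *= 2.0f;
    x[i + 1] *= 2.0f;
    x[i + 2] *= 2.0f;
  }
}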
if (I->getNumUses() > IL_MaxRerollIterations) - return false; + return; - if ((I->getOpcode() == Instruction::Mul || - I->getOpcode() == Instruction::PHI) && - I != IV && - findRootsBase(I, SubsumedInsts)) - return true; + if (I != IV && findRootsBase(I, SubsumedInsts)) + return; SubsumedInsts.insert(I); for (User *V : I->users()) { - Instruction *I = dyn_cast<Instruction>(V); - if (std::find(LoopIncs.begin(), LoopIncs.end(), I) != LoopIncs.end()) + Instruction *I = cast<Instruction>(V); + if (is_contained(LoopIncs, I)) continue; - if (!I || !isSimpleArithmeticOp(I) || - !findRootsRecursive(I, SubsumedInsts)) - return false; + if (!isSimpleArithmeticOp(I)) + continue; + + // The recursive call makes a copy of SubsumedInsts. + findRootsRecursive(I, SubsumedInsts); } +} + +bool LoopReroll::DAGRootTracker::validateRootSet(DAGRootSet &DRS) { + if (DRS.Roots.empty()) + return false; + + // Consider a DAGRootSet with N-1 roots (so N different values including + // BaseInst). + // Define d = Roots[0] - BaseInst, which should be the same as + // Roots[I] - Roots[I-1] for all I in [1..N). + // Define D = BaseInst@J - BaseInst@J-1, where "@J" means the value at the + // loop iteration J. + // + // Now, For the loop iterations to be consecutive: + // D = d * N + const auto *ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst)); + if (!ADR) + return false; + unsigned N = DRS.Roots.size() + 1; + const SCEV *StepSCEV = SE->getMinusSCEV(SE->getSCEV(DRS.Roots[0]), ADR); + const SCEV *ScaleSCEV = SE->getConstant(StepSCEV->getType(), N); + if (ADR->getStepRecurrence(*SE) != SE->getMulExpr(StepSCEV, ScaleSCEV)) + return false; + return true; } bool LoopReroll::DAGRootTracker:: findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts) { - - // The base instruction needs to be a multiply so - // that we can erase it. - if (IVU->getOpcode() != Instruction::Mul && - IVU->getOpcode() != Instruction::PHI) + // The base of a RootSet must be an AddRec, so it can be erased. + const auto *IVU_ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IVU)); + if (!IVU_ADR || IVU_ADR->getLoop() != L) return false; std::map<int64_t, Instruction*> V; @@ -910,6 +933,8 @@ findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts) { DAGRootSet DRS; DRS.BaseInst = nullptr; + SmallVector<DAGRootSet, 16> PotentialRootSets; + for (auto &KV : V) { if (!DRS.BaseInst) { DRS.BaseInst = KV.second; @@ -920,13 +945,22 @@ findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts) { DRS.Roots.push_back(KV.second); } else { // Linear sequence terminated. - RootSets.push_back(DRS); + if (!validateRootSet(DRS)) + return false; + + // Construct a new DAGRootSet with the next sequence. + PotentialRootSets.push_back(DRS); DRS.BaseInst = KV.second; - DRS.SubsumedInsts = SubsumedInsts; DRS.Roots.clear(); } } - RootSets.push_back(DRS); + + if (!validateRootSet(DRS)) + return false; + + PotentialRootSets.push_back(DRS); + + RootSets.append(PotentialRootSets.begin(), PotentialRootSets.end()); return true; } @@ -940,8 +974,7 @@ bool LoopReroll::DAGRootTracker::findRoots() { if (isLoopIncrement(IVU, IV)) LoopIncs.push_back(cast<Instruction>(IVU)); } - if (!findRootsRecursive(IV, SmallInstructionSet())) - return false; + findRootsRecursive(IV, SmallInstructionSet()); LoopIncs.push_back(IV); } else { if (!findRootsBase(IV, SmallInstructionSet())) @@ -961,31 +994,6 @@ bool LoopReroll::DAGRootTracker::findRoots() { } } - // And ensure all loop iterations are consecutive. We rely on std::map - // providing ordered traversal. 
- for (auto &V : RootSets) { - const auto *ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(V.BaseInst)); - if (!ADR) - return false; - - // Consider a DAGRootSet with N-1 roots (so N different values including - // BaseInst). - // Define d = Roots[0] - BaseInst, which should be the same as - // Roots[I] - Roots[I-1] for all I in [1..N). - // Define D = BaseInst@J - BaseInst@J-1, where "@J" means the value at the - // loop iteration J. - // - // Now, For the loop iterations to be consecutive: - // D = d * N - - unsigned N = V.Roots.size() + 1; - const SCEV *StepSCEV = SE->getMinusSCEV(SE->getSCEV(V.Roots[0]), ADR); - const SCEV *ScaleSCEV = SE->getConstant(StepSCEV->getType(), N); - if (ADR->getStepRecurrence(*SE) != SE->getMulExpr(StepSCEV, ScaleSCEV)) { - DEBUG(dbgs() << "LRR: Aborting because iterations are not consecutive\n"); - return false; - } - } Scale = RootSets[0].Roots.size() + 1; if (Scale > IL_MaxRerollIterations) { @@ -1088,7 +1096,7 @@ bool LoopReroll::DAGRootTracker::isBaseInst(Instruction *I) { bool LoopReroll::DAGRootTracker::isRootInst(Instruction *I) { for (auto &DRS : RootSets) { - if (std::find(DRS.Roots.begin(), DRS.Roots.end(), I) != DRS.Roots.end()) + if (is_contained(DRS.Roots, I)) return true; } return false; @@ -1283,7 +1291,7 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { // which while a valid (somewhat arbitrary) micro-optimization, is // needed because otherwise isSafeToSpeculativelyExecute returns // false on PHI nodes. - if (!isa<PHINode>(I) && !isSimpleLoadStore(I) && + if (!isa<PHINode>(I) && !isUnorderedLoadStore(I) && !isSafeToSpeculativelyExecute(I)) // Intervening instructions cause side effects. FutureSideEffects = true; @@ -1313,10 +1321,10 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { // If we've past an instruction from a future iteration that may have // side effects, and this instruction might also, then we can't reorder // them, and this matching fails. As an exception, we allow the alias - // set tracker to handle regular (simple) load/store dependencies. - if (FutureSideEffects && ((!isSimpleLoadStore(BaseInst) && + // set tracker to handle regular (unordered) load/store dependencies. + if (FutureSideEffects && ((!isUnorderedLoadStore(BaseInst) && !isSafeToSpeculativelyExecute(BaseInst)) || - (!isSimpleLoadStore(RootInst) && + (!isUnorderedLoadStore(RootInst) && !isSafeToSpeculativelyExecute(RootInst)))) { DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst << " vs. " << *RootInst << @@ -1412,13 +1420,12 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) { BasicBlock *Header = L->getHeader(); // Remove instructions associated with non-base iterations. - for (BasicBlock::reverse_iterator J = Header->rbegin(); - J != Header->rend();) { + for (BasicBlock::reverse_iterator J = Header->rbegin(), JE = Header->rend(); + J != JE;) { unsigned I = Uses[&*J].find_first(); if (I > 0 && I < IL_All) { - Instruction *D = &*J; - DEBUG(dbgs() << "LRR: removing: " << *D << "\n"); - D->eraseFromParent(); + DEBUG(dbgs() << "LRR: removing: " << *J << "\n"); + J++->eraseFromParent(); continue; } @@ -1499,8 +1506,8 @@ void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst, { // Limit the lifetime of SCEVExpander. 
const DataLayout &DL = Header->getModule()->getDataLayout(); SCEVExpander Expander(*SE, DL, "reroll"); - Value *NewIV = - Expander.expandCodeFor(NewIVSCEV, InstIV->getType(), &Header->front()); + Value *NewIV = Expander.expandCodeFor(NewIVSCEV, Inst->getType(), + Header->getFirstNonPHIOrDbg()); for (auto &KV : Uses) if (KV.second.find_first() == 0) diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp index 7a06a25..cc83069 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -14,13 +14,12 @@ #include "llvm/Transforms/Scalar/LoopRotation.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -34,6 +33,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -326,6 +326,10 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // Otherwise, stick the new instruction into the new block! C->setName(Inst->getName()); C->insertBefore(LoopEntryBranch); + + if (auto *II = dyn_cast<IntrinsicInst>(C)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC->registerAssumption(II); } } @@ -501,7 +505,8 @@ static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, // GEPs are cheap if all indices are constant. if (!cast<GEPOperator>(I)->hasAllConstantIndices()) return false; - // fall-thru to increment case + // fall-thru to increment case + LLVM_FALLTHROUGH; case Instruction::Add: case Instruction::Sub: case Instruction::And: @@ -617,21 +622,14 @@ bool LoopRotate::processLoop(Loop *L) { return MadeChange; } -LoopRotatePass::LoopRotatePass() {} - -PreservedAnalyses LoopRotatePass::run(Loop &L, AnalysisManager<Loop> &AM) { - auto &FAM = AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); - Function *F = L.getHeader()->getParent(); - - auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); - const auto *TTI = FAM.getCachedResult<TargetIRAnalysis>(*F); - auto *AC = FAM.getCachedResult<AssumptionAnalysis>(*F); - assert((LI && TTI && AC) && "Analyses for loop rotation not available"); +LoopRotatePass::LoopRotatePass(bool EnableHeaderDuplication) + : EnableHeaderDuplication(EnableHeaderDuplication) {} - // Optional analyses. - auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); - auto *SE = FAM.getCachedResult<ScalarEvolutionAnalysis>(*F); - LoopRotate LR(DefaultRotationThreshold, LI, TTI, AC, DT, SE); +PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + int Threshold = EnableHeaderDuplication ? 
DefaultRotationThreshold : 0; + LoopRotate LR(Threshold, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE); bool Changed = LR.processLoop(&L); if (!Changed) diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index ec22793..1606121 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -18,18 +18,18 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; @@ -64,16 +64,10 @@ static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI) { return Changed; } -PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, AnalysisManager<Loop> &AM) { - const auto &FAM = - AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); - Function *F = L.getHeader()->getParent(); - - auto *LI = FAM.getCachedResult<LoopAnalysis>(*F); - auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(*F); - assert((LI && DT) && "Analyses for LoopSimplifyCFG not available"); - - if (!simplifyLoopCFG(L, *DT, *LI)) +PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + if (!simplifyLoopCFG(L, AR.DT, AR.LI)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp new file mode 100644 index 0000000..f3f4152 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp @@ -0,0 +1,335 @@ +//===-- LoopSink.cpp - Loop Sink Pass ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass does the inverse transformation of what LICM does. +// It traverses all of the instructions in the loop's preheader and sinks +// them to the loop body where frequency is lower than the loop's preheader. +// This pass is a reverse-transformation of LICM. 
It differs from the Sink +// pass in the following ways: +// +// * It only handles sinking of instructions from the loop's preheader to the +// loop's body +// * It uses alias set tracker to get more accurate alias info +// * It uses block frequency info to find the optimal sinking locations +// +// Overall algorithm: +// +// For I in Preheader: +// InsertBBs = BBs that uses I +// For BB in sorted(LoopBBs): +// DomBBs = BBs in InsertBBs that are dominated by BB +// if freq(DomBBs) > freq(BB) +// InsertBBs = UseBBs - DomBBs + BB +// For BB in InsertBBs: +// Insert I at BB's beginning +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +using namespace llvm; + +#define DEBUG_TYPE "loopsink" + +STATISTIC(NumLoopSunk, "Number of instructions sunk into loop"); +STATISTIC(NumLoopSunkCloned, "Number of cloned instructions sunk into loop"); + +static cl::opt<unsigned> SinkFrequencyPercentThreshold( + "sink-freq-percent-threshold", cl::Hidden, cl::init(90), + cl::desc("Do not sink instructions that require cloning unless they " + "execute less than this percent of the time.")); + +static cl::opt<unsigned> MaxNumberOfUseBBsForSinking( + "max-uses-for-sinking", cl::Hidden, cl::init(30), + cl::desc("Do not sink instructions that have too many uses.")); + +/// Return adjusted total frequency of \p BBs. +/// +/// * If there is only one BB, sinking instruction will not introduce code +/// size increase. Thus there is no need to adjust the frequency. +/// * If there are more than one BB, sinking would lead to code size increase. +/// In this case, we add some "tax" to the total frequency to make it harder +/// to sink. E.g. +/// Freq(Preheader) = 100 +/// Freq(BBs) = sum(50, 49) = 99 +/// Even if Freq(BBs) < Freq(Preheader), we will not sink from Preheade to +/// BBs as the difference is too small to justify the code size increase. +/// To model this, The adjusted Freq(BBs) will be: +/// AdjustedFreq(BBs) = 99 / SinkFrequencyPercentThreshold% +static BlockFrequency adjustedSumFreq(SmallPtrSetImpl<BasicBlock *> &BBs, + BlockFrequencyInfo &BFI) { + BlockFrequency T = 0; + for (BasicBlock *B : BBs) + T += BFI.getBlockFreq(B); + if (BBs.size() > 1) + T /= BranchProbability(SinkFrequencyPercentThreshold, 100); + return T; +} + +/// Return a set of basic blocks to insert sinked instructions. 
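The code-size "tax" that adjustedSumFreq applies above can be reproduced outside LLVM with the numbers from its own comment: a preheader at frequency 100 versus two use blocks at 50 and 49. The sketch below is plain C++ for illustration, not the pass itself; the 90 matches the default of -sink-freq-percent-threshold.

#include <cstdio>
#include <vector>

// Standalone model of adjustedSumFreq: sum the block frequencies and, when
// more than one block is involved, divide by ThresholdPercent% to penalize
// the code growth caused by cloning.
static double adjustedSumFreq(const std::vector<double> &Freqs,
                              double ThresholdPercent = 90.0) {
  double Sum = 0;
  for (double F : Freqs)
    Sum += F;
  if (Freqs.size() > 1)
    Sum /= ThresholdPercent / 100.0;
  return Sum;
}

int main() {
  const double PreheaderFreq = 100.0;
  // Two blocks: 50 + 49 = 99, adjusted to 110 -> worse than the preheader,
  // so sinking is rejected even though the raw sum is smaller.
  std::printf("two blocks: %.1f vs preheader %.0f\n",
              adjustedSumFreq({50, 49}), PreheaderFreq);
  // A single block needs no tax: 80 < 100, so sinking is profitable.
  std::printf("one block:  %.1f vs preheader %.0f\n",
              adjustedSumFreq({80}), PreheaderFreq);
  return 0;
}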
+/// +/// The returned set of basic blocks (BBsToSinkInto) should satisfy: +/// +/// * Inside the loop \p L +/// * For each UseBB in \p UseBBs, there is at least one BB in BBsToSinkInto +/// that domintates the UseBB +/// * Has minimum total frequency that is no greater than preheader frequency +/// +/// The purpose of the function is to find the optimal sinking points to +/// minimize execution cost, which is defined as "sum of frequency of +/// BBsToSinkInto". +/// As a result, the returned BBsToSinkInto needs to have minimum total +/// frequency. +/// Additionally, if the total frequency of BBsToSinkInto exceeds preheader +/// frequency, the optimal solution is not sinking (return empty set). +/// +/// \p ColdLoopBBs is used to help find the optimal sinking locations. +/// It stores a list of BBs that is: +/// +/// * Inside the loop \p L +/// * Has a frequency no larger than the loop's preheader +/// * Sorted by BB frequency +/// +/// The complexity of the function is O(UseBBs.size() * ColdLoopBBs.size()). +/// To avoid expensive computation, we cap the maximum UseBBs.size() in its +/// caller. +static SmallPtrSet<BasicBlock *, 2> +findBBsToSinkInto(const Loop &L, const SmallPtrSetImpl<BasicBlock *> &UseBBs, + const SmallVectorImpl<BasicBlock *> &ColdLoopBBs, + DominatorTree &DT, BlockFrequencyInfo &BFI) { + SmallPtrSet<BasicBlock *, 2> BBsToSinkInto; + if (UseBBs.size() == 0) + return BBsToSinkInto; + + BBsToSinkInto.insert(UseBBs.begin(), UseBBs.end()); + SmallPtrSet<BasicBlock *, 2> BBsDominatedByColdestBB; + + // For every iteration: + // * Pick the ColdestBB from ColdLoopBBs + // * Find the set BBsDominatedByColdestBB that satisfy: + // - BBsDominatedByColdestBB is a subset of BBsToSinkInto + // - Every BB in BBsDominatedByColdestBB is dominated by ColdestBB + // * If Freq(ColdestBB) < Freq(BBsDominatedByColdestBB), remove + // BBsDominatedByColdestBB from BBsToSinkInto, add ColdestBB to + // BBsToSinkInto + for (BasicBlock *ColdestBB : ColdLoopBBs) { + BBsDominatedByColdestBB.clear(); + for (BasicBlock *SinkedBB : BBsToSinkInto) + if (DT.dominates(ColdestBB, SinkedBB)) + BBsDominatedByColdestBB.insert(SinkedBB); + if (BBsDominatedByColdestBB.size() == 0) + continue; + if (adjustedSumFreq(BBsDominatedByColdestBB, BFI) > + BFI.getBlockFreq(ColdestBB)) { + for (BasicBlock *DominatedBB : BBsDominatedByColdestBB) { + BBsToSinkInto.erase(DominatedBB); + } + BBsToSinkInto.insert(ColdestBB); + } + } + + // If the total frequency of BBsToSinkInto is larger than preheader frequency, + // do not sink. + if (adjustedSumFreq(BBsToSinkInto, BFI) > + BFI.getBlockFreq(L.getLoopPreheader())) + BBsToSinkInto.clear(); + return BBsToSinkInto; +} + +// Sinks \p I from the loop \p L's preheader to its uses. Returns true if +// sinking is successful. +// \p LoopBlockNumber is used to sort the insertion blocks to ensure +// determinism. +static bool sinkInstruction(Loop &L, Instruction &I, + const SmallVectorImpl<BasicBlock *> &ColdLoopBBs, + const SmallDenseMap<BasicBlock *, int, 16> &LoopBlockNumber, + LoopInfo &LI, DominatorTree &DT, + BlockFrequencyInfo &BFI) { + // Compute the set of blocks in loop L which contain a use of I. + SmallPtrSet<BasicBlock *, 2> BBs; + for (auto &U : I.uses()) { + Instruction *UI = cast<Instruction>(U.getUser()); + // We cannot sink I to PHI-uses. + if (dyn_cast<PHINode>(UI)) + return false; + // We cannot sink I if it has uses outside of the loop. 
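The greedy walk in findBBsToSinkInto above is easier to follow on a toy CFG. The following is standalone C++ with hard-coded frequencies and dominance, not LLVM code; the block names and numbers are invented, and the 90% tax from adjustedSumFreq is left out because it does not change this example's outcome.

#include <iostream>
#include <map>
#include <set>
#include <string>

int main() {
  // Toy loop body: "guard" (freq 60) dominates "then" (25) and "else" (45);
  // the preheader runs at frequency 100 and the candidate instruction is
  // used in both "then" and "else".
  std::map<std::string, int> Freq = {{"guard", 60}, {"then", 25}, {"else", 45}};
  std::map<std::string, std::set<std::string>> Dom = {
      {"then", {"then"}},
      {"else", {"else"}},
      {"guard", {"guard", "then", "else"}}};
  std::set<std::string> SinkInto = {"then", "else"};

  // Candidate blocks colder than the preheader, already sorted coldest-first.
  for (std::string Cold : {"then", "else", "guard"}) {
    std::set<std::string> Dominated;
    for (const std::string &BB : SinkInto)
      if (Dom[Cold].count(BB))
        Dominated.insert(BB);
    if (Dominated.empty())
      continue;
    int DominatedFreq = 0;
    for (const std::string &BB : Dominated)
      DominatedFreq += Freq[BB];
    // Collapse the dominated copies into Cold only when, together, they are
    // hotter than Cold itself.
    if (DominatedFreq > Freq[Cold]) {
      for (const std::string &BB : Dominated)
        SinkInto.erase(BB);
      SinkInto.insert(Cold);
    }
  }

  // "then" + "else" = 70 > 60, so both copies collapse into "guard"; and
  // 60 <= 100 means sinking out of the preheader still pays off.
  for (const std::string &BB : SinkInto)
    std::cout << "sink into: " << BB << '\n';
  return 0;
}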
+ if (!L.contains(LI.getLoopFor(UI->getParent()))) + return false; + BBs.insert(UI->getParent()); + } + + // findBBsToSinkInto is O(BBs.size() * ColdLoopBBs.size()). We cap the max + // BBs.size() to avoid expensive computation. + // FIXME: Handle code size growth for min_size and opt_size. + if (BBs.size() > MaxNumberOfUseBBsForSinking) + return false; + + // Find the set of BBs that we should insert a copy of I. + SmallPtrSet<BasicBlock *, 2> BBsToSinkInto = + findBBsToSinkInto(L, BBs, ColdLoopBBs, DT, BFI); + if (BBsToSinkInto.empty()) + return false; + + // Copy the final BBs into a vector and sort them using the total ordering + // of the loop block numbers as iterating the set doesn't give a useful + // order. No need to stable sort as the block numbers are a total ordering. + SmallVector<BasicBlock *, 2> SortedBBsToSinkInto; + SortedBBsToSinkInto.insert(SortedBBsToSinkInto.begin(), BBsToSinkInto.begin(), + BBsToSinkInto.end()); + std::sort(SortedBBsToSinkInto.begin(), SortedBBsToSinkInto.end(), + [&](BasicBlock *A, BasicBlock *B) { + return *LoopBlockNumber.find(A) < *LoopBlockNumber.find(B); + }); + + BasicBlock *MoveBB = *SortedBBsToSinkInto.begin(); + // FIXME: Optimize the efficiency for cloned value replacement. The current + // implementation is O(SortedBBsToSinkInto.size() * I.num_uses()). + for (BasicBlock *N : SortedBBsToSinkInto) { + if (N == MoveBB) + continue; + // Clone I and replace its uses. + Instruction *IC = I.clone(); + IC->setName(I.getName()); + IC->insertBefore(&*N->getFirstInsertionPt()); + // Replaces uses of I with IC in N + for (Value::use_iterator UI = I.use_begin(), UE = I.use_end(); UI != UE;) { + Use &U = *UI++; + auto *I = cast<Instruction>(U.getUser()); + if (I->getParent() == N) + U.set(IC); + } + // Replaces uses of I with IC in blocks dominated by N + replaceDominatedUsesWith(&I, IC, DT, N); + DEBUG(dbgs() << "Sinking a clone of " << I << " To: " << N->getName() + << '\n'); + NumLoopSunkCloned++; + } + DEBUG(dbgs() << "Sinking " << I << " To: " << MoveBB->getName() << '\n'); + NumLoopSunk++; + I.moveBefore(&*MoveBB->getFirstInsertionPt()); + + return true; +} + +/// Sinks instructions from loop's preheader to the loop body if the +/// sum frequency of inserted copy is smaller than preheader's frequency. +static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, + DominatorTree &DT, + BlockFrequencyInfo &BFI, + ScalarEvolution *SE) { + BasicBlock *Preheader = L.getLoopPreheader(); + if (!Preheader) + return false; + + // Enable LoopSink only when runtime profile is available. + // With static profile, the sinking decision may be sub-optimal. + if (!Preheader->getParent()->getEntryCount()) + return false; + + const BlockFrequency PreheaderFreq = BFI.getBlockFreq(Preheader); + // If there are no basic blocks with lower frequency than the preheader then + // we can avoid the detailed analysis as we will never find profitable sinking + // opportunities. + if (all_of(L.blocks(), [&](const BasicBlock *BB) { + return BFI.getBlockFreq(BB) > PreheaderFreq; + })) + return false; + + bool Changed = false; + AliasSetTracker CurAST(AA); + + // Compute alias set. 
+ for (BasicBlock *BB : L.blocks()) + CurAST.add(*BB); + + // Sort loop's basic blocks by frequency + SmallVector<BasicBlock *, 10> ColdLoopBBs; + SmallDenseMap<BasicBlock *, int, 16> LoopBlockNumber; + int i = 0; + for (BasicBlock *B : L.blocks()) + if (BFI.getBlockFreq(B) < BFI.getBlockFreq(L.getLoopPreheader())) { + ColdLoopBBs.push_back(B); + LoopBlockNumber[B] = ++i; + } + std::stable_sort(ColdLoopBBs.begin(), ColdLoopBBs.end(), + [&](BasicBlock *A, BasicBlock *B) { + return BFI.getBlockFreq(A) < BFI.getBlockFreq(B); + }); + + // Traverse preheader's instructions in reverse order becaue if A depends + // on B (A appears after B), A needs to be sinked first before B can be + // sinked. + for (auto II = Preheader->rbegin(), E = Preheader->rend(); II != E;) { + Instruction *I = &*II++; + // No need to check for instruction's operands are loop invariant. + assert(L.hasLoopInvariantOperands(I) && + "Insts in a loop's preheader should have loop invariant operands!"); + if (!canSinkOrHoistInst(*I, &AA, &DT, &L, &CurAST, nullptr)) + continue; + if (sinkInstruction(L, *I, ColdLoopBBs, LoopBlockNumber, LI, DT, BFI)) + Changed = true; + } + + if (Changed && SE) + SE->forgetLoopDispositions(&L); + return Changed; +} + +namespace { +struct LegacyLoopSinkPass : public LoopPass { + static char ID; + LegacyLoopSinkPass() : LoopPass(ID) { + initializeLegacyLoopSinkPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); + return sinkLoopInvariantInstructions( + *L, getAnalysis<AAResultsWrapperPass>().getAAResults(), + getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), + getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(), + SE ? 
&SE->getSE() : nullptr); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<BlockFrequencyInfoWrapperPass>(); + getLoopAnalysisUsage(AU); + } +}; +} + +char LegacyLoopSinkPass::ID = 0; +INITIALIZE_PASS_BEGIN(LegacyLoopSinkPass, "loop-sink", "Loop Sink", false, + false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_END(LegacyLoopSinkPass, "loop-sink", "Loop Sink", false, false) + +Pass *llvm::createLoopSinkPass() { return new LegacyLoopSinkPass(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 70bd9d3..194587a 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -53,29 +53,64 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopStrengthReduce.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/IVUsers.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ScalarEvolutionNormalization.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/IR/OperandTraits.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <iterator> +#include <map> +#include <tuple> +#include <utility> + using namespace llvm; #define DEBUG_TYPE "loop-reduce" @@ -123,8 +158,9 @@ struct MemAccessTy { bool operator!=(MemAccessTy Other) const { return !(*this == Other); } - static MemAccessTy getUnknown(LLVMContext &Ctx) { - return MemAccessTy(Type::getVoidTy(Ctx), UnknownAddressSpace); + static MemAccessTy getUnknown(LLVMContext &Ctx, + unsigned AS = UnknownAddressSpace) { + return MemAccessTy(Type::getVoidTy(Ctx), AS); } }; @@ -139,7 +175,7 @@ public: void dump() const; }; -} +} // end anonymous namespace void RegSortData::print(raw_ostream &OS) const { OS << "[NumUses=" << UsedByIndices.count() << ']'; @@ 
-178,7 +214,7 @@ public: const_iterator end() const { return RegSequence.end(); } }; -} +} // end anonymous namespace void RegUseTracker::countRegister(const SCEV *Reg, size_t LUIdx) { @@ -210,7 +246,7 @@ RegUseTracker::swapAndDropUse(size_t LUIdx, size_t LastLUIdx) { SmallBitVector &UsedByIndices = Pair.second.UsedByIndices; if (LUIdx < UsedByIndices.size()) UsedByIndices[LUIdx] = - LastLUIdx < UsedByIndices.size() ? UsedByIndices[LastLUIdx] : 0; + LastLUIdx < UsedByIndices.size() ? UsedByIndices[LastLUIdx] : false; UsedByIndices.resize(std::min(UsedByIndices.size(), LastLUIdx)); } } @@ -301,7 +337,7 @@ struct Formula { void dump() const; }; -} +} // end anonymous namespace /// Recursion helper for initialMatch. static void DoInitialMatch(const SCEV *S, Loop *L, @@ -323,7 +359,7 @@ static void DoInitialMatch(const SCEV *S, Loop *L, // Look at addrec operands. if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) - if (!AR->getStart()->isZero()) { + if (!AR->getStart()->isZero() && AR->isAffine()) { DoInitialMatch(AR->getStart(), L, Good, Bad, SE); DoInitialMatch(SE.getAddRecExpr(SE.getConstant(AR->getType(), 0), AR->getStepRecurrence(SE), @@ -446,8 +482,7 @@ void Formula::deleteBaseReg(const SCEV *&S) { /// Test if this formula references the given register. bool Formula::referencesReg(const SCEV *S) const { - return S == ScaledReg || - std::find(BaseRegs.begin(), BaseRegs.end(), S) != BaseRegs.end(); + return S == ScaledReg || is_contained(BaseRegs, S); } /// Test whether this formula uses registers which are used by uses other than @@ -567,7 +602,7 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS, // Distribute the sdiv over addrec operands, if the addrec doesn't overflow. if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS)) { - if (IgnoreSignificantBits || isAddRecSExtable(AR, SE)) { + if ((IgnoreSignificantBits || isAddRecSExtable(AR, SE)) && AR->isAffine()) { const SCEV *Step = getExactSDiv(AR->getStepRecurrence(SE), RHS, SE, IgnoreSignificantBits); if (!Step) return nullptr; @@ -822,8 +857,10 @@ DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakVH> &DeadInsts) { } namespace { + class LSRUse; -} + +} // end anonymous namespace /// \brief Check if the addressing mode defined by \p F is completely /// folded in \p LU at isel time. @@ -883,7 +920,6 @@ public: SmallPtrSetImpl<const SCEV *> &Regs, const DenseSet<const SCEV *> &VisitedRegs, const Loop *L, - const SmallVectorImpl<int64_t> &Offsets, ScalarEvolution &SE, DominatorTree &DT, const LSRUse &LU, SmallPtrSetImpl<const SCEV *> *LoserRegs = nullptr); @@ -902,8 +938,144 @@ private: ScalarEvolution &SE, DominatorTree &DT, SmallPtrSetImpl<const SCEV *> *LoserRegs); }; + +/// An operand value in an instruction which is to be replaced with some +/// equivalent, possibly strength-reduced, replacement. +struct LSRFixup { + /// The instruction which will be updated. + Instruction *UserInst; -} + /// The operand of the instruction which will be replaced. The operand may be + /// used more than once; every instance will be replaced. + Value *OperandValToReplace; + + /// If this user is to use the post-incremented value of an induction + /// variable, this variable is non-null and holds the loop associated with the + /// induction variable. + PostIncLoopSet PostIncLoops; + + /// A constant offset to be added to the LSRUse expression. This allows + /// multiple fixups to share the same LSRUse with different offsets, for + /// example in an unrolled loop. 
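A side note on the mechanical cleanups scattered through this file (referencesReg, isRootInst, SolveRecurse): llvm::is_contained and llvm::find from STLExtras are direct replacements for the std::find idiom they displace. The standalone analogue below is for illustration only; is_contained_like is a made-up name, not the LLVM helper.

#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

// Equivalent of the replaced idiom: std::find(begin, end, Elt) != end.
template <typename R, typename E>
static bool is_contained_like(const R &Range, const E &Elt) {
  return std::find(std::begin(Range), std::end(Range), Elt) != std::end(Range);
}

int main() {
  std::vector<int> BaseRegs = {1, 4, 9};
  assert(is_contained_like(BaseRegs, 4));
  assert(!is_contained_like(BaseRegs, 7));
  return 0;
}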
+ int64_t Offset; + + bool isUseFullyOutsideLoop(const Loop *L) const; + + LSRFixup(); + + void print(raw_ostream &OS) const; + void dump() const; +}; + +/// A DenseMapInfo implementation for holding DenseMaps and DenseSets of sorted +/// SmallVectors of const SCEV*. +struct UniquifierDenseMapInfo { + static SmallVector<const SCEV *, 4> getEmptyKey() { + SmallVector<const SCEV *, 4> V; + V.push_back(reinterpret_cast<const SCEV *>(-1)); + return V; + } + + static SmallVector<const SCEV *, 4> getTombstoneKey() { + SmallVector<const SCEV *, 4> V; + V.push_back(reinterpret_cast<const SCEV *>(-2)); + return V; + } + + static unsigned getHashValue(const SmallVector<const SCEV *, 4> &V) { + return static_cast<unsigned>(hash_combine_range(V.begin(), V.end())); + } + + static bool isEqual(const SmallVector<const SCEV *, 4> &LHS, + const SmallVector<const SCEV *, 4> &RHS) { + return LHS == RHS; + } +}; + +/// This class holds the state that LSR keeps for each use in IVUsers, as well +/// as uses invented by LSR itself. It includes information about what kinds of +/// things can be folded into the user, information about the user itself, and +/// information about how the use may be satisfied. TODO: Represent multiple +/// users of the same expression in common? +class LSRUse { + DenseSet<SmallVector<const SCEV *, 4>, UniquifierDenseMapInfo> Uniquifier; + +public: + /// An enum for a kind of use, indicating what types of scaled and immediate + /// operands it might support. + enum KindType { + Basic, ///< A normal use, with no folding. + Special, ///< A special case of basic, allowing -1 scales. + Address, ///< An address use; folding according to TargetLowering + ICmpZero ///< An equality icmp with both operands folded into one. + // TODO: Add a generic icmp too? + }; + + typedef PointerIntPair<const SCEV *, 2, KindType> SCEVUseKindPair; + + KindType Kind; + MemAccessTy AccessTy; + + /// The list of operands which are to be replaced. + SmallVector<LSRFixup, 8> Fixups; + + /// Keep track of the min and max offsets of the fixups. + int64_t MinOffset; + int64_t MaxOffset; + + /// This records whether all of the fixups using this LSRUse are outside of + /// the loop, in which case some special-case heuristics may be used. + bool AllFixupsOutsideLoop; + + /// RigidFormula is set to true to guarantee that this use will be associated + /// with a single formula--the one that initially matched. Some SCEV + /// expressions cannot be expanded. This allows LSR to consider the registers + /// used by those expressions without the need to expand them later after + /// changing the formula. + bool RigidFormula; + + /// This records the widest use type for any fixup using this + /// LSRUse. FindUseWithSimilarFormula can't consider uses with different max + /// fixup widths to be equivalent, because the narrower one may be relying on + /// the implicit truncation to truncate away bogus bits. + Type *WidestFixupType; + + /// A list of ways to build a value that can satisfy this user. After the + /// list is populated, one of these is selected heuristically and used to + /// formulate a replacement for OperandValToReplace in UserInst. + SmallVector<Formula, 12> Formulae; + + /// The set of register candidates used by all formulae in this LSRUse. 
+ SmallPtrSet<const SCEV *, 4> Regs; + + LSRUse(KindType K, MemAccessTy AT) + : Kind(K), AccessTy(AT), MinOffset(INT64_MAX), MaxOffset(INT64_MIN), + AllFixupsOutsideLoop(true), RigidFormula(false), + WidestFixupType(nullptr) {} + + LSRFixup &getNewFixup() { + Fixups.push_back(LSRFixup()); + return Fixups.back(); + } + + void pushFixup(LSRFixup &f) { + Fixups.push_back(f); + if (f.Offset > MaxOffset) + MaxOffset = f.Offset; + if (f.Offset < MinOffset) + MinOffset = f.Offset; + } + + bool HasFormulaWithSameRegs(const Formula &F) const; + bool InsertFormula(const Formula &F); + void DeleteFormula(Formula &F); + void RecomputeRegs(size_t LUIdx, RegUseTracker &Reguses); + + void print(raw_ostream &OS) const; + void dump() const; +}; + +} // end anonymous namespace /// Tally up interesting quantities from the given register. void Cost::RateRegister(const SCEV *Reg, @@ -975,7 +1147,6 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, SmallPtrSetImpl<const SCEV *> &Regs, const DenseSet<const SCEV *> &VisitedRegs, const Loop *L, - const SmallVectorImpl<int64_t> &Offsets, ScalarEvolution &SE, DominatorTree &DT, const LSRUse &LU, SmallPtrSetImpl<const SCEV *> *LoserRegs) { @@ -1013,13 +1184,20 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, ScaleCost += getScalingFactorCost(TTI, LU, F); // Tally up the non-zero immediates. - for (int64_t O : Offsets) { + for (const LSRFixup &Fixup : LU.Fixups) { + int64_t O = Fixup.Offset; int64_t Offset = (uint64_t)O + F.BaseOffset; if (F.BaseGV) ImmCost += 64; // Handle symbolic values conservatively. // TODO: This should probably be the pointer size. else if (Offset != 0) ImmCost += APInt(64, Offset, true).getMinSignedBits(); + + // Check with target if this offset with this instruction is + // specifically not supported. + if ((isa<LoadInst>(Fixup.UserInst) || isa<StoreInst>(Fixup.UserInst)) && + !TTI.isFoldableMemAccessOffset(Fixup.UserInst, Offset)) + NumBaseAdds++; } assert(isValid() && "invalid cost"); } @@ -1066,44 +1244,8 @@ void Cost::dump() const { print(errs()); errs() << '\n'; } -namespace { - -/// An operand value in an instruction which is to be replaced with some -/// equivalent, possibly strength-reduced, replacement. -struct LSRFixup { - /// The instruction which will be updated. - Instruction *UserInst; - - /// The operand of the instruction which will be replaced. The operand may be - /// used more than once; every instance will be replaced. - Value *OperandValToReplace; - - /// If this user is to use the post-incremented value of an induction - /// variable, this variable is non-null and holds the loop associated with the - /// induction variable. - PostIncLoopSet PostIncLoops; - - /// The index of the LSRUse describing the expression which this fixup needs, - /// minus an offset (below). - size_t LUIdx; - - /// A constant offset to be added to the LSRUse expression. This allows - /// multiple fixups to share the same LSRUse with different offsets, for - /// example in an unrolled loop. - int64_t Offset; - - bool isUseFullyOutsideLoop(const Loop *L) const; - - LSRFixup(); - - void print(raw_ostream &OS) const; - void dump() const; -}; - -} - LSRFixup::LSRFixup() - : UserInst(nullptr), OperandValToReplace(nullptr), LUIdx(~size_t(0)), + : UserInst(nullptr), OperandValToReplace(nullptr), Offset(0) {} /// Test whether this fixup always uses its value outside of the given loop. 
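RateFormula above charges each non-zero immediate by APInt(64, Offset, true).getMinSignedBits(), i.e. the smallest signed bit-width that can hold the offset. The standalone model below computes the same quantity for a 64-bit value without LLVM; minSignedBits is a made-up name used only for illustration.

#include <cassert>
#include <cstdint>

// Smallest signed bit-width that can hold V: 64 minus the number of leading
// bits that merely repeat the sign bit.
static unsigned minSignedBits(int64_t V) {
  uint64_t U = static_cast<uint64_t>(V);
  unsigned SignBit = (U >> 63) & 1u;
  unsigned NumSignBits = 1;
  for (int Bit = 62; Bit >= 0 && ((U >> Bit) & 1u) == SignBit; --Bit)
    ++NumSignBits;
  return 64 - NumSignBits + 1;
}

int main() {
  assert(minSignedBits(0) == 1);
  assert(minSignedBits(255) == 9);   // 8 value bits plus a sign bit
  assert(minSignedBits(-256) == 9);  // -256..255 fits in 9 signed bits
  assert(minSignedBits(-1) == 1);
  return 0;
}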
@@ -1139,9 +1281,6 @@ void LSRFixup::print(raw_ostream &OS) const { PIL->getHeader()->printAsOperand(OS, /*PrintType=*/false); } - if (LUIdx != ~size_t(0)) - OS << ", LUIdx=" << LUIdx; - if (Offset != 0) OS << ", Offset=" << Offset; } @@ -1151,102 +1290,6 @@ void LSRFixup::dump() const { print(errs()); errs() << '\n'; } -namespace { - -/// A DenseMapInfo implementation for holding DenseMaps and DenseSets of sorted -/// SmallVectors of const SCEV*. -struct UniquifierDenseMapInfo { - static SmallVector<const SCEV *, 4> getEmptyKey() { - SmallVector<const SCEV *, 4> V; - V.push_back(reinterpret_cast<const SCEV *>(-1)); - return V; - } - - static SmallVector<const SCEV *, 4> getTombstoneKey() { - SmallVector<const SCEV *, 4> V; - V.push_back(reinterpret_cast<const SCEV *>(-2)); - return V; - } - - static unsigned getHashValue(const SmallVector<const SCEV *, 4> &V) { - return static_cast<unsigned>(hash_combine_range(V.begin(), V.end())); - } - - static bool isEqual(const SmallVector<const SCEV *, 4> &LHS, - const SmallVector<const SCEV *, 4> &RHS) { - return LHS == RHS; - } -}; - -/// This class holds the state that LSR keeps for each use in IVUsers, as well -/// as uses invented by LSR itself. It includes information about what kinds of -/// things can be folded into the user, information about the user itself, and -/// information about how the use may be satisfied. TODO: Represent multiple -/// users of the same expression in common? -class LSRUse { - DenseSet<SmallVector<const SCEV *, 4>, UniquifierDenseMapInfo> Uniquifier; - -public: - /// An enum for a kind of use, indicating what types of scaled and immediate - /// operands it might support. - enum KindType { - Basic, ///< A normal use, with no folding. - Special, ///< A special case of basic, allowing -1 scales. - Address, ///< An address use; folding according to TargetLowering - ICmpZero ///< An equality icmp with both operands folded into one. - // TODO: Add a generic icmp too? - }; - - typedef PointerIntPair<const SCEV *, 2, KindType> SCEVUseKindPair; - - KindType Kind; - MemAccessTy AccessTy; - - SmallVector<int64_t, 8> Offsets; - int64_t MinOffset; - int64_t MaxOffset; - - /// This records whether all of the fixups using this LSRUse are outside of - /// the loop, in which case some special-case heuristics may be used. - bool AllFixupsOutsideLoop; - - /// RigidFormula is set to true to guarantee that this use will be associated - /// with a single formula--the one that initially matched. Some SCEV - /// expressions cannot be expanded. This allows LSR to consider the registers - /// used by those expressions without the need to expand them later after - /// changing the formula. - bool RigidFormula; - - /// This records the widest use type for any fixup using this - /// LSRUse. FindUseWithSimilarFormula can't consider uses with different max - /// fixup widths to be equivalent, because the narrower one may be relying on - /// the implicit truncation to truncate away bogus bits. - Type *WidestFixupType; - - /// A list of ways to build a value that can satisfy this user. After the - /// list is populated, one of these is selected heuristically and used to - /// formulate a replacement for OperandValToReplace in UserInst. - SmallVector<Formula, 12> Formulae; - - /// The set of register candidates used by all formulae in this LSRUse. 
- SmallPtrSet<const SCEV *, 4> Regs; - - LSRUse(KindType K, MemAccessTy AT) - : Kind(K), AccessTy(AT), MinOffset(INT64_MAX), MaxOffset(INT64_MIN), - AllFixupsOutsideLoop(true), RigidFormula(false), - WidestFixupType(nullptr) {} - - bool HasFormulaWithSameRegs(const Formula &F) const; - bool InsertFormula(const Formula &F); - void DeleteFormula(Formula &F); - void RecomputeRegs(size_t LUIdx, RegUseTracker &Reguses); - - void print(raw_ostream &OS) const; - void dump() const; -}; - -} - /// Test whether this use as a formula which has the same registers as the given /// formula. bool LSRUse::HasFormulaWithSameRegs(const Formula &F) const { @@ -1334,9 +1377,9 @@ void LSRUse::print(raw_ostream &OS) const { OS << ", Offsets={"; bool NeedComma = false; - for (int64_t O : Offsets) { + for (const LSRFixup &Fixup : Fixups) { if (NeedComma) OS << ','; - OS << O; + OS << Fixup.Offset; NeedComma = true; } OS << '}'; @@ -1638,14 +1681,16 @@ class LSRInstance { Instruction *IVIncInsertPos; /// Interesting factors between use strides. - SmallSetVector<int64_t, 8> Factors; + /// + /// We explicitly use a SetVector which contains a SmallSet, instead of the + /// default, a SmallDenseSet, because we need to use the full range of + /// int64_ts, and there's currently no good way of doing that with + /// SmallDenseSet. + SetVector<int64_t, SmallVector<int64_t, 8>, SmallSet<int64_t, 8>> Factors; /// Interesting use types, to facilitate truncation reuse. SmallSetVector<Type *, 4> Types; - /// The list of operands which are to be replaced. - SmallVector<LSRFixup, 16> Fixups; - /// The list of interesting uses. SmallVector<LSRUse, 16> Uses; @@ -1678,11 +1723,6 @@ class LSRInstance { void CollectInterestingTypesAndFactors(); void CollectFixupsAndInitialFormulae(); - LSRFixup &getNewFixup() { - Fixups.push_back(LSRFixup()); - return Fixups.back(); - } - // Support for sharing of LSRUses between LSRFixups. typedef DenseMap<LSRUse::SCEVUseKindPair, size_t> UseMapTy; UseMapTy UseMap; @@ -1752,16 +1792,16 @@ class LSRInstance { const LSRUse &LU, SCEVExpander &Rewriter) const; - Value *Expand(const LSRFixup &LF, + Value *Expand(const LSRUse &LU, const LSRFixup &LF, const Formula &F, BasicBlock::iterator IP, SCEVExpander &Rewriter, SmallVectorImpl<WeakVH> &DeadInsts) const; - void RewriteForPHI(PHINode *PN, const LSRFixup &LF, + void RewriteForPHI(PHINode *PN, const LSRUse &LU, const LSRFixup &LF, const Formula &F, SCEVExpander &Rewriter, SmallVectorImpl<WeakVH> &DeadInsts) const; - void Rewrite(const LSRFixup &LF, + void Rewrite(const LSRUse &LU, const LSRFixup &LF, const Formula &F, SCEVExpander &Rewriter, SmallVectorImpl<WeakVH> &DeadInsts) const; @@ -1780,7 +1820,7 @@ public: void dump() const; }; -} +} // end anonymous namespace /// If IV is used in a int-to-float cast inside the loop then try to eliminate /// the cast operation. @@ -2068,10 +2108,30 @@ void LSRInstance::OptimizeLoopTermCond() { SmallPtrSet<Instruction *, 4> PostIncs; + // We need a different set of heuristics for rotated and non-rotated loops. + // If a loop is rotated then the latch is also the backedge, so inserting + // post-inc expressions just before the latch is ideal. To reduce live ranges + // it also makes sense to rewrite terminating conditions to use post-inc + // expressions. + // + // If the loop is not rotated then the latch is not a backedge; the latch + // check is done in the loop head. Adding post-inc expressions before the + // latch will cause overlapping live-ranges of pre-inc and post-inc expressions + // in the loop body. 
In this case we do *not* want to use post-inc expressions + // in the latch check, and we want to insert post-inc expressions before + // the backedge. BasicBlock *LatchBlock = L->getLoopLatch(); SmallVector<BasicBlock*, 8> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); + if (llvm::all_of(ExitingBlocks, [&LatchBlock](const BasicBlock *BB) { + return LatchBlock != BB; + })) { + // The backedge doesn't exit the loop; treat this as a head-tested loop. + IVIncInsertPos = LatchBlock->getTerminator(); + return; + } + // Otherwise treat this as a rotated loop. for (BasicBlock *ExitingBlock : ExitingBlocks) { // Get the terminating condition for the loop if possible. If we @@ -2220,8 +2280,10 @@ bool LSRInstance::reconcileNewOffset(LSRUse &LU, int64_t NewOffset, // TODO: Be less conservative when the type is similar and can use the same // addressing modes. if (Kind == LSRUse::Address) { - if (AccessTy != LU.AccessTy) - NewAccessTy = MemAccessTy::getUnknown(AccessTy.MemTy->getContext()); + if (AccessTy.MemTy != LU.AccessTy.MemTy) { + NewAccessTy = MemAccessTy::getUnknown(AccessTy.MemTy->getContext(), + AccessTy.AddrSpace); + } } // Conservatively assume HasBaseReg is true for now. @@ -2241,8 +2303,6 @@ bool LSRInstance::reconcileNewOffset(LSRUse &LU, int64_t NewOffset, LU.MinOffset = NewMinOffset; LU.MaxOffset = NewMaxOffset; LU.AccessTy = NewAccessTy; - if (NewOffset != LU.Offsets.back()) - LU.Offsets.push_back(NewOffset); return true; } @@ -2279,11 +2339,6 @@ std::pair<size_t, int64_t> LSRInstance::getUse(const SCEV *&Expr, Uses.push_back(LSRUse(Kind, AccessTy)); LSRUse &LU = Uses[LUIdx]; - // We don't need to track redundant offsets, but we don't need to go out - // of our way here to avoid them. - if (LU.Offsets.empty() || Offset != LU.Offsets.back()) - LU.Offsets.push_back(Offset); - LU.MinOffset = Offset; LU.MaxOffset = Offset; return std::make_pair(LUIdx, Offset); @@ -2500,7 +2555,7 @@ bool IVChain::isProfitableIncrement(const SCEV *OperExpr, if (!isa<SCEVConstant>(IncExpr)) { const SCEV *HeadExpr = SE.getSCEV(getWideOperand(Incs[0].IVOperand)); if (isa<SCEVConstant>(SE.getMinusSCEV(OperExpr, HeadExpr))) - return 0; + return false; } SmallPtrSet<const SCEV*, 8> Processed; @@ -2797,9 +2852,8 @@ void LSRInstance::FinalizeChain(IVChain &Chain) { DEBUG(dbgs() << "Final Chain: " << *Chain.Incs[0].UserInst << "\n"); for (const IVInc &Inc : Chain) { - DEBUG(dbgs() << " Inc: " << Inc.UserInst << "\n"); - auto UseI = std::find(Inc.UserInst->op_begin(), Inc.UserInst->op_end(), - Inc.IVOperand); + DEBUG(dbgs() << " Inc: " << *Inc.UserInst << "\n"); + auto UseI = find(Inc.UserInst->operands(), Inc.IVOperand); assert(UseI != Inc.UserInst->op_end() && "cannot find IV operand"); IVIncSet.insert(UseI); } @@ -2932,39 +2986,34 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { for (const IVStrideUse &U : IU) { Instruction *UserInst = U.getUser(); // Skip IV users that are part of profitable IV Chains. - User::op_iterator UseI = std::find(UserInst->op_begin(), UserInst->op_end(), - U.getOperandValToReplace()); + User::op_iterator UseI = + find(UserInst->operands(), U.getOperandValToReplace()); assert(UseI != UserInst->op_end() && "cannot find IV operand"); if (IVIncSet.count(UseI)) continue; - // Record the uses. 
- LSRFixup &LF = getNewFixup(); - LF.UserInst = UserInst; - LF.OperandValToReplace = U.getOperandValToReplace(); - LF.PostIncLoops = U.getPostIncLoops(); - LSRUse::KindType Kind = LSRUse::Basic; MemAccessTy AccessTy; - if (isAddressUse(LF.UserInst, LF.OperandValToReplace)) { + if (isAddressUse(UserInst, U.getOperandValToReplace())) { Kind = LSRUse::Address; - AccessTy = getAccessType(LF.UserInst); + AccessTy = getAccessType(UserInst); } const SCEV *S = IU.getExpr(U); - + PostIncLoopSet TmpPostIncLoops = U.getPostIncLoops(); + // Equality (== and !=) ICmps are special. We can rewrite (i == N) as // (N - i == 0), and this allows (N - i) to be the expression that we work // with rather than just N or i, so we can consider the register // requirements for both N and i at the same time. Limiting this code to // equality icmps is not a problem because all interesting loops use // equality icmps, thanks to IndVarSimplify. - if (ICmpInst *CI = dyn_cast<ICmpInst>(LF.UserInst)) + if (ICmpInst *CI = dyn_cast<ICmpInst>(UserInst)) if (CI->isEquality()) { // Swap the operands if needed to put the OperandValToReplace on the // left, for consistency. Value *NV = CI->getOperand(1); - if (NV == LF.OperandValToReplace) { + if (NV == U.getOperandValToReplace()) { CI->setOperand(1, CI->getOperand(0)); CI->setOperand(0, NV); NV = CI->getOperand(1); @@ -2977,7 +3026,7 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { // S is normalized, so normalize N before folding it into S // to keep the result normalized. N = TransformForPostIncUse(Normalize, N, CI, nullptr, - LF.PostIncLoops, SE, DT); + TmpPostIncLoops, SE, DT); Kind = LSRUse::ICmpZero; S = SE.getMinusSCEV(N, S); } @@ -2990,12 +3039,20 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { Factors.insert(-1); } - // Set up the initial formula for this use. + // Get or create an LSRUse. std::pair<size_t, int64_t> P = getUse(S, Kind, AccessTy); - LF.LUIdx = P.first; - LF.Offset = P.second; - LSRUse &LU = Uses[LF.LUIdx]; + size_t LUIdx = P.first; + int64_t Offset = P.second; + LSRUse &LU = Uses[LUIdx]; + + // Record the fixup. + LSRFixup &LF = LU.getNewFixup(); + LF.UserInst = UserInst; + LF.OperandValToReplace = U.getOperandValToReplace(); + LF.PostIncLoops = TmpPostIncLoops; + LF.Offset = Offset; LU.AllFixupsOutsideLoop &= LF.isUseFullyOutsideLoop(L); + if (!LU.WidestFixupType || SE.getTypeSizeInBits(LU.WidestFixupType) < SE.getTypeSizeInBits(LF.OperandValToReplace->getType())) @@ -3003,8 +3060,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { // If this is the first use of this LSRUse, give it a formula. if (LU.Formulae.empty()) { - InsertInitialFormula(S, LU, LF.LUIdx); - CountRegisters(LU.Formulae.back(), LF.LUIdx); + InsertInitialFormula(S, LU, LUIdx); + CountRegisters(LU.Formulae.back(), LUIdx); } } @@ -3109,6 +3166,9 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() { // Don't bother if the instruction is in a BB which ends in an EHPad. if (UseBB->getTerminator()->isEHPad()) continue; + // Don't bother rewriting PHIs in catchswitch blocks. + if (isa<CatchSwitchInst>(UserInst->getParent()->getTerminator())) + continue; // Ignore uses which are part of other SCEV expressions, to avoid // analyzing them multiple times. 
if (SE.isSCEVable(UserInst->getType())) { @@ -3130,20 +3190,21 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() { continue; } - LSRFixup &LF = getNewFixup(); - LF.UserInst = const_cast<Instruction *>(UserInst); - LF.OperandValToReplace = U; std::pair<size_t, int64_t> P = getUse( S, LSRUse::Basic, MemAccessTy()); - LF.LUIdx = P.first; - LF.Offset = P.second; - LSRUse &LU = Uses[LF.LUIdx]; + size_t LUIdx = P.first; + int64_t Offset = P.second; + LSRUse &LU = Uses[LUIdx]; + LSRFixup &LF = LU.getNewFixup(); + LF.UserInst = const_cast<Instruction *>(UserInst); + LF.OperandValToReplace = U; + LF.Offset = Offset; LU.AllFixupsOutsideLoop &= LF.isUseFullyOutsideLoop(L); if (!LU.WidestFixupType || SE.getTypeSizeInBits(LU.WidestFixupType) < SE.getTypeSizeInBits(LF.OperandValToReplace->getType())) LU.WidestFixupType = LF.OperandValToReplace->getType(); - InsertSupplementalFormula(US, LU, LF.LUIdx); + InsertSupplementalFormula(US, LU, LUIdx); CountRegisters(LU.Formulae.back(), Uses.size() - 1); break; } @@ -3175,7 +3236,7 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C, return nullptr; } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { // Split a non-zero base out of an addrec. - if (AR->getStart()->isZero()) + if (AR->getStart()->isZero() || !AR->isAffine()) return S; const SCEV *Remainder = CollectSubexprs(AR->getStart(), @@ -3629,7 +3690,7 @@ struct WorkItem { void dump() const; }; -} +} // end anonymous namespace void WorkItem::print(raw_ostream &OS) const { OS << "in formulae referencing " << *OrigReg << " in use " << LUIdx @@ -3872,8 +3933,7 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { // the corresponding bad register from the Regs set. Cost CostF; Regs.clear(); - CostF.RateFormula(TTI, F, Regs, VisitedRegs, L, LU.Offsets, SE, DT, LU, - &LoserRegs); + CostF.RateFormula(TTI, F, Regs, VisitedRegs, L, SE, DT, LU, &LoserRegs); if (CostF.isLoser()) { // During initial formula generation, undesirable formulae are generated // by uses within other loops that have some non-trivial address mode or @@ -3906,8 +3966,7 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { Cost CostBest; Regs.clear(); - CostBest.RateFormula(TTI, Best, Regs, VisitedRegs, L, LU.Offsets, SE, - DT, LU); + CostBest.RateFormula(TTI, Best, Regs, VisitedRegs, L, SE, DT, LU); if (CostF < CostBest) std::swap(F, Best); DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs()); @@ -4053,25 +4112,13 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { LUThatHas->AllFixupsOutsideLoop &= LU.AllFixupsOutsideLoop; - // Update the relocs to reference the new use. - for (LSRFixup &Fixup : Fixups) { - if (Fixup.LUIdx == LUIdx) { - Fixup.LUIdx = LUThatHas - &Uses.front(); - Fixup.Offset += F.BaseOffset; - // Add the new offset to LUThatHas' offset list. - if (LUThatHas->Offsets.back() != Fixup.Offset) { - LUThatHas->Offsets.push_back(Fixup.Offset); - if (Fixup.Offset > LUThatHas->MaxOffset) - LUThatHas->MaxOffset = Fixup.Offset; - if (Fixup.Offset < LUThatHas->MinOffset) - LUThatHas->MinOffset = Fixup.Offset; - } - DEBUG(dbgs() << "New fixup has offset " << Fixup.Offset << '\n'); - } - if (Fixup.LUIdx == NumUses-1) - Fixup.LUIdx = LUIdx; + // Transfer the fixups of LU to LUThatHas. + for (LSRFixup &Fixup : LU.Fixups) { + Fixup.Offset += F.BaseOffset; + LUThatHas->pushFixup(Fixup); + DEBUG(dbgs() << "New fixup has offset " << Fixup.Offset << '\n'); } - + // Delete formulae from the new use which are no longer legal. 
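The hunk above moves fixups into the LSRUse that owns them: transferring a fixup is now a pushFixup call that also keeps MinOffset/MaxOffset current, instead of patching LUIdx fields in a global fixup list. A simplified standalone model of that ownership (plain C++, invented names, no LLVM types):

#include <cstdint>
#include <iostream>
#include <vector>

struct Fixup { int64_t Offset; };

struct Use {
  std::vector<Fixup> Fixups;
  int64_t MinOffset = INT64_MAX;
  int64_t MaxOffset = INT64_MIN;
  // Mirrors LSRUse::pushFixup: append and keep the offset bounds in sync.
  void pushFixup(const Fixup &F) {
    Fixups.push_back(F);
    if (F.Offset > MaxOffset) MaxOffset = F.Offset;
    if (F.Offset < MinOffset) MinOffset = F.Offset;
  }
};

int main() {
  std::vector<Use> Uses(1);
  Uses[0].pushFixup({8});   // e.g. a fixup folded over from a collapsed use
  Uses[0].pushFixup({-4});
  for (const Use &U : Uses)            // rewriting walks uses, then fixups
    for (const Fixup &F : U.Fixups)
      std::cout << "fixup offset " << F.Offset << '\n';
  std::cout << "min=" << Uses[0].MinOffset
            << " max=" << Uses[0].MaxOffset << '\n';   // min=-4 max=8
  return 0;
}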
bool Any = false; for (size_t i = 0, e = LUThatHas->Formulae.size(); i != e; ++i) { @@ -4137,9 +4184,10 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() { for (const SCEV *Reg : RegUses) { if (Taken.count(Reg)) continue; - if (!Best) + if (!Best) { Best = Reg; - else { + BestNum = RegUses.getUsedByIndices(Reg).count(); + } else { unsigned Count = RegUses.getUsedByIndices(Reg).count(); if (Count > BestNum) { Best = Reg; @@ -4229,8 +4277,7 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, int NumReqRegsToFind = std::min(F.getNumRegs(), ReqRegs.size()); for (const SCEV *Reg : ReqRegs) { if ((F.ScaledReg && F.ScaledReg == Reg) || - std::find(F.BaseRegs.begin(), F.BaseRegs.end(), Reg) != - F.BaseRegs.end()) { + is_contained(F.BaseRegs, Reg)) { --NumReqRegsToFind; if (NumReqRegsToFind == 0) break; @@ -4246,8 +4293,7 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, // the current best, prune the search at that point. NewCost = CurCost; NewRegs = CurRegs; - NewCost.RateFormula(TTI, F, NewRegs, VisitedRegs, L, LU.Offsets, SE, DT, - LU); + NewCost.RateFormula(TTI, F, NewRegs, VisitedRegs, L, SE, DT, LU); if (NewCost < SolutionCost) { Workspace.push_back(&F); if (Workspace.size() != Uses.size()) { @@ -4313,7 +4359,7 @@ LSRInstance::HoistInsertPosition(BasicBlock::iterator IP, const SmallVectorImpl<Instruction *> &Inputs) const { Instruction *Tentative = &*IP; - for (;;) { + while (true) { bool AllDominate = true; Instruction *BetterPos = nullptr; // Don't bother attempting to insert before a catchswitch, their basic block @@ -4430,12 +4476,12 @@ LSRInstance::AdjustInsertPositionForExpand(BasicBlock::iterator LowestIP, /// Emit instructions for the leading candidate expression for this LSRUse (this /// is called "expanding"). -Value *LSRInstance::Expand(const LSRFixup &LF, +Value *LSRInstance::Expand(const LSRUse &LU, + const LSRFixup &LF, const Formula &F, BasicBlock::iterator IP, SCEVExpander &Rewriter, SmallVectorImpl<WeakVH> &DeadInsts) const { - const LSRUse &LU = Uses[LF.LUIdx]; if (LU.RigidFormula) return LF.OperandValToReplace; @@ -4617,6 +4663,7 @@ Value *LSRInstance::Expand(const LSRFixup &LF, /// effectively happens in their predecessor blocks, so the expression may need /// to be expanded in multiple places. void LSRInstance::RewriteForPHI(PHINode *PN, + const LSRUse &LU, const LSRFixup &LF, const Formula &F, SCEVExpander &Rewriter, @@ -4631,7 +4678,8 @@ void LSRInstance::RewriteForPHI(PHINode *PN, // is the canonical backedge for this loop, which complicates post-inc // users. if (e != 1 && BB->getTerminator()->getNumSuccessors() > 1 && - !isa<IndirectBrInst>(BB->getTerminator())) { + !isa<IndirectBrInst>(BB->getTerminator()) && + !isa<CatchSwitchInst>(BB->getTerminator())) { BasicBlock *Parent = PN->getParent(); Loop *PNLoop = LI.getLoopFor(Parent); if (!PNLoop || Parent != PNLoop->getHeader()) { @@ -4670,7 +4718,7 @@ void LSRInstance::RewriteForPHI(PHINode *PN, if (!Pair.second) PN->setIncomingValue(i, Pair.first->second); else { - Value *FullV = Expand(LF, F, BB->getTerminator()->getIterator(), + Value *FullV = Expand(LU, LF, F, BB->getTerminator()->getIterator(), Rewriter, DeadInsts); // If this is reuse-by-noop-cast, insert the noop cast. @@ -4691,17 +4739,18 @@ void LSRInstance::RewriteForPHI(PHINode *PN, /// Emit instructions for the leading candidate expression for this LSRUse (this /// is called "expanding"), and update the UserInst to reference the newly /// expanded value. 
-void LSRInstance::Rewrite(const LSRFixup &LF, +void LSRInstance::Rewrite(const LSRUse &LU, + const LSRFixup &LF, const Formula &F, SCEVExpander &Rewriter, SmallVectorImpl<WeakVH> &DeadInsts) const { // First, find an insertion point that dominates UserInst. For PHI nodes, // find the nearest block which dominates all the relevant uses. if (PHINode *PN = dyn_cast<PHINode>(LF.UserInst)) { - RewriteForPHI(PN, LF, F, Rewriter, DeadInsts); + RewriteForPHI(PN, LU, LF, F, Rewriter, DeadInsts); } else { Value *FullV = - Expand(LF, F, LF.UserInst->getIterator(), Rewriter, DeadInsts); + Expand(LU, LF, F, LF.UserInst->getIterator(), Rewriter, DeadInsts); // If this is reuse-by-noop-cast, insert the noop cast. Type *OpTy = LF.OperandValToReplace->getType(); @@ -4717,7 +4766,7 @@ void LSRInstance::Rewrite(const LSRFixup &LF, // its new value may happen to be equal to LF.OperandValToReplace, in // which case doing replaceUsesOfWith leads to replacing both operands // with the same value. TODO: Reorganize this. - if (Uses[LF.LUIdx].Kind == LSRUse::ICmpZero) + if (LU.Kind == LSRUse::ICmpZero) LF.UserInst->setOperand(0, FullV); else LF.UserInst->replaceUsesOfWith(LF.OperandValToReplace, FullV); @@ -4750,11 +4799,11 @@ void LSRInstance::ImplementSolution( } // Expand the new value definitions and update the users. - for (const LSRFixup &Fixup : Fixups) { - Rewrite(Fixup, *Solution[Fixup.LUIdx], Rewriter, DeadInsts); - - Changed = true; - } + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) + for (const LSRFixup &Fixup : Uses[LUIdx].Fixups) { + Rewrite(Uses[LUIdx], Fixup, *Solution[LUIdx], Rewriter, DeadInsts); + Changed = true; + } for (const IVChain &Chain : IVChainVec) { GenerateIVChain(Chain, Rewriter, DeadInsts); @@ -4898,11 +4947,12 @@ void LSRInstance::print_factors_and_types(raw_ostream &OS) const { void LSRInstance::print_fixups(raw_ostream &OS) const { OS << "LSR is examining the following fixup sites:\n"; - for (const LSRFixup &LF : Fixups) { - dbgs() << " "; - LF.print(OS); - OS << '\n'; - } + for (const LSRUse &LU : Uses) + for (const LSRFixup &LF : LU.Fixups) { + dbgs() << " "; + LF.print(OS); + OS << '\n'; + } } void LSRInstance::print_uses(raw_ostream &OS) const { @@ -4935,6 +4985,7 @@ namespace { class LoopStrengthReduce : public LoopPass { public: static char ID; // Pass ID, replacement for typeid + LoopStrengthReduce(); private: @@ -4942,24 +4993,7 @@ private: void getAnalysisUsage(AnalysisUsage &AU) const override; }; -} - -char LoopStrengthReduce::ID = 0; -INITIALIZE_PASS_BEGIN(LoopStrengthReduce, "loop-reduce", - "Loop Strength Reduction", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(IVUsersWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_END(LoopStrengthReduce, "loop-reduce", - "Loop Strength Reduction", false, false) - - -Pass *llvm::createLoopStrengthReducePass() { - return new LoopStrengthReduce(); -} +} // end anonymous namespace LoopStrengthReduce::LoopStrengthReduce() : LoopPass(ID) { initializeLoopStrengthReducePass(*PassRegistry::getPassRegistry()); @@ -4985,16 +5019,9 @@ void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<TargetTransformInfoWrapperPass>(); } -bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { - if (skipLoop(L)) - return false; - - auto &IU = 
getAnalysis<IVUsersWrapperPass>().getIU(); - auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( - *L->getHeader()->getParent()); +static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, + DominatorTree &DT, LoopInfo &LI, + const TargetTransformInfo &TTI) { bool Changed = false; // Run the main LSR transformation. @@ -5005,15 +5032,11 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { if (EnablePhiElim && L->isLoopSimplifyForm()) { SmallVector<WeakVH, 16> DeadInsts; const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - SCEVExpander Rewriter(getAnalysis<ScalarEvolutionWrapperPass>().getSE(), DL, - "lsr"); + SCEVExpander Rewriter(SE, DL, "lsr"); #ifndef NDEBUG Rewriter.setDebugType(DEBUG_TYPE); #endif - unsigned numFolded = Rewriter.replaceCongruentIVs( - L, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), DeadInsts, - &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( - *L->getHeader()->getParent())); + unsigned numFolded = Rewriter.replaceCongruentIVs(L, &DT, DeadInsts, &TTI); if (numFolded) { Changed = true; DeleteTriviallyDeadInstructions(DeadInsts); @@ -5022,3 +5045,40 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { } return Changed; } + +bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { + if (skipLoop(L)) + return false; + + auto &IU = getAnalysis<IVUsersWrapperPass>().getIU(); + auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( + *L->getHeader()->getParent()); + return ReduceLoopStrength(L, IU, SE, DT, LI, TTI); +} + +PreservedAnalyses LoopStrengthReducePass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + if (!ReduceLoopStrength(&L, AM.getResult<IVUsersAnalysis>(L, AR), AR.SE, + AR.DT, AR.LI, AR.TTI)) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +char LoopStrengthReduce::ID = 0; +INITIALIZE_PASS_BEGIN(LoopStrengthReduce, "loop-reduce", + "Loop Strength Reduction", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(IVUsersWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_END(LoopStrengthReduce, "loop-reduce", + "Loop Strength Reduction", false, false) + +Pass *llvm::createLoopStrengthReducePass() { return new LoopStrengthReduce(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index 91af4a1..c7f9122 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -12,6 +12,7 @@ // counts of loops easily. 
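The LSR hunks above split the transformation into a static worker (ReduceLoopStrength) that both the legacy runOnLoop and the new pass manager's LoopStrengthReducePass::run call. A minimal sketch of that dual-driver shape follows, assuming the headers this diff itself adds; MyLSRLikePass and reduceStrengthImpl are hypothetical names, not part of the tree.

#include "llvm/Transforms/Scalar/LoopPassManager.h"
using namespace llvm;

// Placeholder worker: the real one takes the analyses it needs by reference
// and reports whether it changed the loop.
static bool reduceStrengthImpl(Loop &L, ScalarEvolution &SE) {
  (void)L;
  (void)SE;
  return false;
}

struct MyLSRLikePass : PassInfoMixin<MyLSRLikePass> {
  PreservedAnalyses run(Loop &L, LoopAnalysisManager &,
                        LoopStandardAnalysisResults &AR, LPMUpdater &) {
    // Under the new loop pass manager the analyses arrive pre-computed in
    // AR, so there are no getCachedResult calls or "analysis not available"
    // asserts as in the old run method.
    if (!reduceStrengthImpl(L, AR.SE))
      return PreservedAnalyses::all();
    return getLoopPassPreservedAnalyses();
  }
};

A legacy wrapper would call the same worker from runOnLoop after fetching its analyses through getAnalysis<>, which is exactly the shape the LoopStrengthReduce hunks above end up with.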
//===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/LoopUnrollPass.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" @@ -19,11 +20,10 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/LoopUnrollAnalyzer.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/IntrinsicInst.h" @@ -32,6 +32,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/UnrollLoop.h" #include <climits> @@ -45,16 +46,14 @@ static cl::opt<unsigned> UnrollThreshold("unroll-threshold", cl::Hidden, cl::desc("The baseline cost threshold for loop unrolling")); -static cl::opt<unsigned> UnrollPercentDynamicCostSavedThreshold( - "unroll-percent-dynamic-cost-saved-threshold", cl::init(50), cl::Hidden, - cl::desc("The percentage of estimated dynamic cost which must be saved by " - "unrolling to allow unrolling up to the max threshold.")); - -static cl::opt<unsigned> UnrollDynamicCostSavingsDiscount( - "unroll-dynamic-cost-savings-discount", cl::init(100), cl::Hidden, - cl::desc("This is the amount discounted from the total unroll cost when " - "the unrolled form has a high dynamic cost savings (triggered by " - "the '-unroll-perecent-dynamic-cost-saved-threshold' flag).")); +static cl::opt<unsigned> UnrollMaxPercentThresholdBoost( + "unroll-max-percent-threshold-boost", cl::init(400), cl::Hidden, + cl::desc("The maximum 'boost' (represented as a percentage >= 100) applied " + "to the threshold when aggressively unrolling a loop due to the " + "dynamic cost savings. If completely unrolling a loop will reduce " + "the total runtime from X to Y, we boost the loop unroll " + "threshold to DefaultThreshold*std::min(MaxPercentThresholdBoost, " + "X/Y). 
This limit avoids excessive code bloat.")); static cl::opt<unsigned> UnrollMaxIterationsCountToAnalyze( "unroll-max-iteration-count-to-analyze", cl::init(10), cl::Hidden, @@ -90,43 +89,59 @@ static cl::opt<bool> UnrollRuntime("unroll-runtime", cl::ZeroOrMore, cl::Hidden, cl::desc("Unroll loops with run-time trip counts")); +static cl::opt<unsigned> UnrollMaxUpperBound( + "unroll-max-upperbound", cl::init(8), cl::Hidden, + cl::desc( + "The max of trip count upper bound that is considered in unrolling")); + static cl::opt<unsigned> PragmaUnrollThreshold( "pragma-unroll-threshold", cl::init(16 * 1024), cl::Hidden, cl::desc("Unrolled size limit for loops with an unroll(full) or " "unroll_count pragma.")); +static cl::opt<unsigned> FlatLoopTripCountThreshold( + "flat-loop-tripcount-threshold", cl::init(5), cl::Hidden, + cl::desc("If the runtime tripcount for the loop is lower than the " + "threshold, the loop is considered as flat and will be less " + "aggressively unrolled.")); + +static cl::opt<bool> + UnrollAllowPeeling("unroll-allow-peeling", cl::Hidden, + cl::desc("Allows loops to be peeled when the dynamic " + "trip count is known to be low.")); + /// A magic value for use with the Threshold parameter to indicate /// that the loop unroll should be performed regardless of how much /// code expansion would result. static const unsigned NoThreshold = UINT_MAX; -/// Default unroll count for loops with run-time trip count if -/// -unroll-count is not set -static const unsigned DefaultUnrollRuntimeCount = 8; - /// Gather the various unrolling parameters based on the defaults, compiler /// flags, TTI overrides and user specified parameters. static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( Loop *L, const TargetTransformInfo &TTI, Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, - Optional<bool> UserRuntime) { + Optional<bool> UserRuntime, Optional<bool> UserUpperBound) { TargetTransformInfo::UnrollingPreferences UP; // Set up the defaults UP.Threshold = 150; - UP.PercentDynamicCostSavedThreshold = 50; - UP.DynamicCostSavingsDiscount = 100; + UP.MaxPercentThresholdBoost = 400; UP.OptSizeThreshold = 0; UP.PartialThreshold = UP.Threshold; UP.PartialOptSizeThreshold = 0; UP.Count = 0; + UP.PeelCount = 0; + UP.DefaultUnrollRuntimeCount = 8; UP.MaxCount = UINT_MAX; UP.FullUnrollMaxCount = UINT_MAX; + UP.BEInsns = 2; UP.Partial = false; UP.Runtime = false; UP.AllowRemainder = true; UP.AllowExpensiveTripCount = false; UP.Force = false; + UP.UpperBound = false; + UP.AllowPeeling = false; // Override with any target specific settings TTI.getUnrollingPreferences(L, UP); @@ -142,11 +157,8 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.Threshold = UnrollThreshold; UP.PartialThreshold = UnrollThreshold; } - if (UnrollPercentDynamicCostSavedThreshold.getNumOccurrences() > 0) - UP.PercentDynamicCostSavedThreshold = - UnrollPercentDynamicCostSavedThreshold; - if (UnrollDynamicCostSavingsDiscount.getNumOccurrences() > 0) - UP.DynamicCostSavingsDiscount = UnrollDynamicCostSavingsDiscount; + if (UnrollMaxPercentThresholdBoost.getNumOccurrences() > 0) + UP.MaxPercentThresholdBoost = UnrollMaxPercentThresholdBoost; if (UnrollMaxCount.getNumOccurrences() > 0) UP.MaxCount = UnrollMaxCount; if (UnrollFullMaxCount.getNumOccurrences() > 0) @@ -157,6 +169,10 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.AllowRemainder = UnrollAllowRemainder; if 
(UnrollRuntime.getNumOccurrences() > 0) UP.Runtime = UnrollRuntime; + if (UnrollMaxUpperBound == 0) + UP.UpperBound = false; + if (UnrollAllowPeeling.getNumOccurrences() > 0) + UP.AllowPeeling = UnrollAllowPeeling; // Apply user values provided by argument if (UserThreshold.hasValue()) { @@ -169,6 +185,8 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.Partial = *UserAllowPartial; if (UserRuntime.hasValue()) UP.Runtime = *UserRuntime; + if (UserUpperBound.hasValue()) + UP.UpperBound = *UserUpperBound; return UP; } @@ -210,11 +228,11 @@ struct UnrolledInstStateKeyInfo { namespace { struct EstimatedUnrollCost { /// \brief The estimated cost after unrolling. - int UnrolledCost; + unsigned UnrolledCost; /// \brief The estimated dynamic cost of executing the instructions in the /// rolled form. - int RolledDynamicCost; + unsigned RolledDynamicCost; }; } @@ -234,7 +252,7 @@ struct EstimatedUnrollCost { static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, ScalarEvolution &SE, const TargetTransformInfo &TTI, - int MaxUnrolledLoopSize) { + unsigned MaxUnrolledLoopSize) { // We want to be able to scale offsets by the trip count and add more offsets // to them without checking for overflows, and we already don't want to // analyze *massive* trip counts, so we force the max to be reasonably small. @@ -258,14 +276,14 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // The estimated cost of the unrolled form of the loop. We try to estimate // this by simplifying as much as we can while computing the estimate. - int UnrolledCost = 0; + unsigned UnrolledCost = 0; // We also track the estimated dynamic (that is, actually executed) cost in // the rolled form. This helps identify cases when the savings from unrolling // aren't just exposing dead control flows, but actual reduced dynamic // instructions due to the simplifications which we expect to occur after // unrolling. - int RolledDynamicCost = 0; + unsigned RolledDynamicCost = 0; // We track the simplification of each instruction in each iteration. We use // this to recursively merge costs into the unrolled cost on-demand so that @@ -412,6 +430,9 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // it. We don't change the actual IR, just count optimization // opportunities. for (Instruction &I : *BB) { + if (isa<DbgInfoIntrinsic>(I)) + continue; + // Track this instruction's expected baseline cost when executing the // rolled loop form. RolledDynamicCost += TTI.getUserCost(&I); @@ -429,16 +450,16 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, if (IsFree) continue; - // If the instruction might have a side-effect recursively account for - // the cost of it and all the instructions leading up to it. - if (I.mayHaveSideEffects()) - AddCostRecursively(I, Iteration); - // Can't properly model a cost of a call. // FIXME: With a proper cost model we should be able to do it. if(isa<CallInst>(&I)) return None; + // If the instruction might have a side-effect recursively account for + // the cost of it and all the instructions leading up to it. + if (I.mayHaveSideEffects()) + AddCostRecursively(I, Iteration); + // If unrolled body turns out to be too big, bail out. if (UnrolledCost > MaxUnrolledLoopSize) { DEBUG(dbgs() << " Exceeded threshold.. 
exiting.\n" @@ -529,7 +550,7 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls, bool &NotDuplicatable, bool &Convergent, const TargetTransformInfo &TTI, - AssumptionCache *AC) { + AssumptionCache *AC, unsigned BEInsns) { SmallPtrSet<const Value *, 32> EphValues; CodeMetrics::collectEphemeralValues(L, AC, EphValues); @@ -548,7 +569,7 @@ static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls, // that each loop has at least three instructions (likely a conditional // branch, a comparison feeding that branch, and some kind of loop increment // feeding that comparison instruction). - LoopSize = std::max(LoopSize, 3u); + LoopSize = std::max(LoopSize, BEInsns + 1); return LoopSize; } @@ -635,70 +656,38 @@ static void SetLoopAlreadyUnrolled(Loop *L) { L->setLoopID(NewLoopID); } -static bool canUnrollCompletely(Loop *L, unsigned Threshold, - unsigned PercentDynamicCostSavedThreshold, - unsigned DynamicCostSavingsDiscount, - uint64_t UnrolledCost, - uint64_t RolledDynamicCost) { - if (Threshold == NoThreshold) { - DEBUG(dbgs() << " Can fully unroll, because no threshold is set.\n"); - return true; - } - - if (UnrolledCost <= Threshold) { - DEBUG(dbgs() << " Can fully unroll, because unrolled cost: " - << UnrolledCost << "<" << Threshold << "\n"); - return true; - } - - assert(UnrolledCost && "UnrolledCost can't be 0 at this point."); - assert(RolledDynamicCost >= UnrolledCost && - "Cannot have a higher unrolled cost than a rolled cost!"); - - // Compute the percentage of the dynamic cost in the rolled form that is - // saved when unrolled. If unrolling dramatically reduces the estimated - // dynamic cost of the loop, we use a higher threshold to allow more - // unrolling. - unsigned PercentDynamicCostSaved = - (uint64_t)(RolledDynamicCost - UnrolledCost) * 100ull / RolledDynamicCost; - - if (PercentDynamicCostSaved >= PercentDynamicCostSavedThreshold && - (int64_t)UnrolledCost - (int64_t)DynamicCostSavingsDiscount <= - (int64_t)Threshold) { - DEBUG(dbgs() << " Can fully unroll, because unrolling will reduce the " - "expected dynamic cost by " - << PercentDynamicCostSaved << "% (threshold: " - << PercentDynamicCostSavedThreshold << "%)\n" - << " and the unrolled cost (" << UnrolledCost - << ") is less than the max threshold (" - << DynamicCostSavingsDiscount << ").\n"); - return true; - } +// Computes the boosting factor for complete unrolling. +// If fully unrolling the loop would save a lot of RolledDynamicCost, it would +// be beneficial to fully unroll the loop even if unrolledcost is large. We +// use (RolledDynamicCost / UnrolledCost) to model the unroll benefits to adjust +// the unroll threshold. 
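//===--- Editorial sketch (not part of the patch) -------------------------===//
// A worked example of the boosting rule described in the comment above, using
// hypothetical cost numbers; `boostedThreshold` is a local name for this
// sketch, not an LLVM API (the UINT_MAX overflow guard of the real
// getFullUnrollBoostingFactor is omitted for brevity).
#include <algorithm>
#include <cstdio>

static unsigned boostedThreshold(unsigned Threshold, unsigned RolledDynamicCost,
                                 unsigned UnrolledCost, unsigned MaxBoost) {
  // Boost is the dynamic-cost saving ratio as a percentage, capped by
  // -unroll-max-percent-threshold-boost.
  unsigned Boost =
      UnrolledCost ? std::min(100 * RolledDynamicCost / UnrolledCost, MaxBoost)
                   : MaxBoost;
  return Threshold * Boost / 100;
}

int main() {
  // With the default threshold of 150 and the default 400% cap, a loop whose
  // estimated dynamic cost drops from 800 to 200 when fully unrolled gets a
  // 4x size budget.
  std::printf("%u\n", boostedThreshold(150, 800, 200, 400)); // prints 600
  return 0;
}
//===----------------------------------------------------------------------===//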
+static unsigned getFullUnrollBoostingFactor(const EstimatedUnrollCost &Cost, + unsigned MaxPercentThresholdBoost) { + if (Cost.RolledDynamicCost >= UINT_MAX / 100) + return 100; + else if (Cost.UnrolledCost != 0) + // The boosting factor is RolledDynamicCost / UnrolledCost + return std::min(100 * Cost.RolledDynamicCost / Cost.UnrolledCost, + MaxPercentThresholdBoost); + else + return MaxPercentThresholdBoost; +} - DEBUG(dbgs() << " Too large to fully unroll:\n"); - DEBUG(dbgs() << " Threshold: " << Threshold << "\n"); - DEBUG(dbgs() << " Max threshold: " << DynamicCostSavingsDiscount << "\n"); - DEBUG(dbgs() << " Percent cost saved threshold: " - << PercentDynamicCostSavedThreshold << "%\n"); - DEBUG(dbgs() << " Unrolled cost: " << UnrolledCost << "\n"); - DEBUG(dbgs() << " Rolled dynamic cost: " << RolledDynamicCost << "\n"); - DEBUG(dbgs() << " Percent cost saved: " << PercentDynamicCostSaved - << "\n"); - return false; +// Returns loop size estimation for unrolled loop. +static uint64_t getUnrolledLoopSize( + unsigned LoopSize, + TargetTransformInfo::UnrollingPreferences &UP) { + assert(LoopSize >= UP.BEInsns && "LoopSize should not be less than BEInsns!"); + return (uint64_t)(LoopSize - UP.BEInsns) * UP.Count + UP.BEInsns; } // Returns true if unroll count was set explicitly. // Calculates unroll count and writes it to UP.Count. -static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, - DominatorTree &DT, LoopInfo *LI, - ScalarEvolution *SE, unsigned TripCount, - unsigned TripMultiple, unsigned LoopSize, - TargetTransformInfo::UnrollingPreferences &UP) { - // BEInsns represents number of instructions optimized when "back edge" - // becomes "fall through" in unrolled loop. - // For now we count a conditional branch on a backedge and a comparison - // feeding it. - unsigned BEInsns = 2; +static bool computeUnrollCount( + Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI, + ScalarEvolution *SE, OptimizationRemarkEmitter *ORE, unsigned &TripCount, + unsigned MaxTripCount, unsigned &TripMultiple, unsigned LoopSize, + TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) { // Check for explicit Count. // 1st priority is unroll count set by "unroll-count" option. 
bool UserUnrollCount = UnrollCount.getNumOccurrences() > 0; @@ -706,8 +695,7 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, UP.Count = UnrollCount; UP.AllowExpensiveTripCount = true; UP.Force = true; - if (UP.AllowRemainder && - (LoopSize - BEInsns) * UP.Count + BEInsns < UP.Threshold) + if (UP.AllowRemainder && getUnrolledLoopSize(LoopSize, UP) < UP.Threshold) return true; } @@ -719,13 +707,13 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, UP.AllowExpensiveTripCount = true; UP.Force = true; if (UP.AllowRemainder && - (LoopSize - BEInsns) * UP.Count + BEInsns < PragmaUnrollThreshold) + getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold) return true; } bool PragmaFullUnroll = HasUnrollFullPragma(L); if (PragmaFullUnroll && TripCount != 0) { UP.Count = TripCount; - if ((LoopSize - BEInsns) * UP.Count + BEInsns < PragmaUnrollThreshold) + if (getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold) return false; } @@ -733,11 +721,6 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, bool ExplicitUnroll = PragmaCount > 0 || PragmaFullUnroll || PragmaEnableUnroll || UserUnrollCount; - uint64_t UnrolledSize; - DebugLoc LoopLoc = L->getStartLoc(); - Function *F = L->getHeader()->getParent(); - LLVMContext &Ctx = F->getContext(); - if (ExplicitUnroll && TripCount != 0) { // If the loop has an unrolling pragma, we want to be more aggressive with // unrolling limits. Set thresholds to at least the PragmaThreshold value @@ -748,38 +731,48 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, } // 3rd priority is full unroll count. - // Full unroll make sense only when TripCount could be staticaly calculated. + // Full unroll makes sense only when TripCount or its upper bound could be + // statically calculated. // Also we need to check if we exceed FullUnrollMaxCount. - if (TripCount && TripCount <= UP.FullUnrollMaxCount) { + // If using the upper bound to unroll, TripMultiple should be set to 1 because + // we do not know when loop may exit. + // MaxTripCount and ExactTripCount cannot both be non zero since we only + // compute the former when the latter is zero. + unsigned ExactTripCount = TripCount; + assert((ExactTripCount == 0 || MaxTripCount == 0) && + "ExtractTripCound and MaxTripCount cannot both be non zero."); + unsigned FullUnrollTripCount = ExactTripCount ? ExactTripCount : MaxTripCount; + UP.Count = FullUnrollTripCount; + if (FullUnrollTripCount && FullUnrollTripCount <= UP.FullUnrollMaxCount) { // When computing the unrolled size, note that BEInsns are not replicated // like the rest of the loop body. - UnrolledSize = (uint64_t)(LoopSize - BEInsns) * TripCount + BEInsns; - if (canUnrollCompletely(L, UP.Threshold, 100, UP.DynamicCostSavingsDiscount, - UnrolledSize, UnrolledSize)) { - UP.Count = TripCount; + if (getUnrolledLoopSize(LoopSize, UP) < UP.Threshold) { + UseUpperBound = (MaxTripCount == FullUnrollTripCount); + TripCount = FullUnrollTripCount; + TripMultiple = UP.UpperBound ? 1 : TripMultiple; return ExplicitUnroll; } else { // The loop isn't that small, but we still can fully unroll it if that // helps to remove a significant number of instructions. // To check that, run additional analysis on the loop. 
if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost( - L, TripCount, DT, *SE, TTI, - UP.Threshold + UP.DynamicCostSavingsDiscount)) - if (canUnrollCompletely(L, UP.Threshold, - UP.PercentDynamicCostSavedThreshold, - UP.DynamicCostSavingsDiscount, - Cost->UnrolledCost, Cost->RolledDynamicCost)) { - UP.Count = TripCount; + L, FullUnrollTripCount, DT, *SE, TTI, + UP.Threshold * UP.MaxPercentThresholdBoost / 100)) { + unsigned Boost = + getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost); + if (Cost->UnrolledCost < UP.Threshold * Boost / 100) { + UseUpperBound = (MaxTripCount == FullUnrollTripCount); + TripCount = FullUnrollTripCount; + TripMultiple = UP.UpperBound ? 1 : TripMultiple; return ExplicitUnroll; } + } } } // 4rd priority is partial unrolling. // Try partial unroll only when TripCount could be staticaly calculated. if (TripCount) { - if (UP.Count == 0) - UP.Count = TripCount; UP.Partial |= ExplicitUnroll; if (!UP.Partial) { DEBUG(dbgs() << " will not try to unroll partially because " @@ -787,12 +780,14 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, UP.Count = 0; return false; } + if (UP.Count == 0) + UP.Count = TripCount; if (UP.PartialThreshold != NoThreshold) { // Reduce unroll count to be modulo of TripCount for partial unrolling. - UnrolledSize = (uint64_t)(LoopSize - BEInsns) * UP.Count + BEInsns; - if (UnrolledSize > UP.PartialThreshold) - UP.Count = (std::max(UP.PartialThreshold, 3u) - BEInsns) / - (LoopSize - BEInsns); + if (getUnrolledLoopSize(LoopSize, UP) > UP.PartialThreshold) + UP.Count = + (std::max(UP.PartialThreshold, UP.BEInsns + 1) - UP.BEInsns) / + (LoopSize - UP.BEInsns); if (UP.Count > UP.MaxCount) UP.Count = UP.MaxCount; while (UP.Count != 0 && TripCount % UP.Count != 0) @@ -802,19 +797,18 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, // largest power-of-two factor that satisfies the threshold limit. // As we'll create fixup loop, do the type of unrolling only if // remainder loop is allowed. 
- UP.Count = DefaultUnrollRuntimeCount; - UnrolledSize = (LoopSize - BEInsns) * UP.Count + BEInsns; - while (UP.Count != 0 && UnrolledSize > UP.PartialThreshold) { + UP.Count = UP.DefaultUnrollRuntimeCount; + while (UP.Count != 0 && + getUnrolledLoopSize(LoopSize, UP) > UP.PartialThreshold) UP.Count >>= 1; - UnrolledSize = (LoopSize - BEInsns) * UP.Count + BEInsns; - } } if (UP.Count < 2) { if (PragmaEnableUnroll) - emitOptimizationRemarkMissed( - Ctx, DEBUG_TYPE, *F, LoopLoc, - "Unable to unroll loop as directed by unroll(enable) pragma " - "because unrolled size is too large."); + ORE->emit( + OptimizationRemarkMissed(DEBUG_TYPE, "UnrollAsDirectedTooLarge", + L->getStartLoc(), L->getHeader()) + << "Unable to unroll loop as directed by unroll(enable) pragma " + "because unrolled size is too large."); UP.Count = 0; } } else { @@ -822,26 +816,48 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, } if ((PragmaFullUnroll || PragmaEnableUnroll) && TripCount && UP.Count != TripCount) - emitOptimizationRemarkMissed( - Ctx, DEBUG_TYPE, *F, LoopLoc, - "Unable to fully unroll loop as directed by unroll pragma because " - "unrolled size is too large."); + ORE->emit( + OptimizationRemarkMissed(DEBUG_TYPE, "FullUnrollAsDirectedTooLarge", + L->getStartLoc(), L->getHeader()) + << "Unable to fully unroll loop as directed by unroll pragma because " + "unrolled size is too large."); return ExplicitUnroll; } assert(TripCount == 0 && "All cases when TripCount is constant should be covered here."); if (PragmaFullUnroll) - emitOptimizationRemarkMissed( - Ctx, DEBUG_TYPE, *F, LoopLoc, - "Unable to fully unroll loop as directed by unroll(full) pragma " - "because loop has a runtime trip count."); + ORE->emit( + OptimizationRemarkMissed(DEBUG_TYPE, + "CantFullUnrollAsDirectedRuntimeTripCount", + L->getStartLoc(), L->getHeader()) + << "Unable to fully unroll loop as directed by unroll(full) pragma " + "because loop has a runtime trip count."); + + // 5th priority is loop peeling + computePeelCount(L, LoopSize, UP); + if (UP.PeelCount) { + UP.Runtime = false; + UP.Count = 1; + return ExplicitUnroll; + } - // 5th priority is runtime unrolling. + // 6th priority is runtime unrolling. // Don't unroll a runtime trip count loop when it is disabled. if (HasRuntimeUnrollDisablePragma(L)) { UP.Count = 0; return false; } + + // Check if the runtime trip count is too small when profile is available. + if (L->getHeader()->getParent()->getEntryCount()) { + if (auto ProfileTripCount = getLoopEstimatedTripCount(L)) { + if (*ProfileTripCount < FlatLoopTripCountThreshold) + return false; + else + UP.AllowExpensiveTripCount = true; + } + } + // Reduce count based on the type of unrolling and the threshold values. UP.Runtime |= PragmaEnableUnroll || PragmaCount > 0 || UserUnrollCount; if (!UP.Runtime) { @@ -851,15 +867,13 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, return false; } if (UP.Count == 0) - UP.Count = DefaultUnrollRuntimeCount; - UnrolledSize = (LoopSize - BEInsns) * UP.Count + BEInsns; + UP.Count = UP.DefaultUnrollRuntimeCount; // Reduce unroll count to be the largest power-of-two factor of // the original count which satisfies the threshold limit. 
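//===--- Editorial sketch (not part of the patch) -------------------------===//
// The size model used throughout computeUnrollCount, pulled out as standalone
// helpers with hypothetical numbers.  Per the patch, the BEInsns backedge
// instructions (the compare and branch) are not replicated per unrolled
// iteration.  Names here are local to the sketch.
#include <cstdint>

static uint64_t unrolledLoopSize(unsigned LoopSize, unsigned BEInsns,
                                 unsigned Count) {
  return (uint64_t)(LoopSize - BEInsns) * Count + BEInsns;
}

// The halving loop just below in the patch: shrink the count until the
// estimated unrolled size fits the partial-unroll threshold.
static unsigned largestFittingPow2Count(unsigned LoopSize, unsigned BEInsns,
                                        unsigned Count,
                                        unsigned PartialThreshold) {
  while (Count != 0 &&
         unrolledLoopSize(LoopSize, BEInsns, Count) > PartialThreshold)
    Count >>= 1;
  return Count;
}

// E.g. with BEInsns = 2, DefaultUnrollRuntimeCount = 8 and a partial threshold
// of 150: a 20-instruction body unrolls 8x (size 146 fits), while a
// 40-instruction body is cut back to 2x (306 -> 154 -> 78).
//===----------------------------------------------------------------------===//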
- while (UP.Count != 0 && UnrolledSize > UP.PartialThreshold) { + while (UP.Count != 0 && + getUnrolledLoopSize(LoopSize, UP) > UP.PartialThreshold) UP.Count >>= 1; - UnrolledSize = (LoopSize - BEInsns) * UP.Count + BEInsns; - } #ifndef NDEBUG unsigned OrigCount = UP.Count; @@ -874,16 +888,19 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, "multiple, " << TripMultiple << ". Reducing unroll count from " << OrigCount << " to " << UP.Count << ".\n"); + using namespace ore; if (PragmaCount > 0 && !UP.AllowRemainder) - emitOptimizationRemarkMissed( - Ctx, DEBUG_TYPE, *F, LoopLoc, - Twine("Unable to unroll loop the number of times directed by " - "unroll_count pragma because remainder loop is restricted " - "(that could architecture specific or because the loop " - "contains a convergent instruction) and so must have an unroll " - "count that divides the loop trip multiple of ") + - Twine(TripMultiple) + ". Unrolling instead " + Twine(UP.Count) + - " time(s)."); + ORE->emit( + OptimizationRemarkMissed(DEBUG_TYPE, + "DifferentUnrollCountFromDirected", + L->getStartLoc(), L->getHeader()) + << "Unable to unroll loop the number of times directed by " + "unroll_count pragma because remainder loop is restricted " + "(that could architecture specific or because the loop " + "contains a convergent instruction) and so must have an unroll " + "count that divides the loop trip multiple of " + << NV("TripMultiple", TripMultiple) << ". Unrolling instead " + << NV("UnrollCount", UP.Count) << " time(s)."); } if (UP.Count > UP.MaxCount) @@ -896,22 +913,34 @@ static bool computeUnrollCount(Loop *L, const TargetTransformInfo &TTI, static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution *SE, const TargetTransformInfo &TTI, - AssumptionCache &AC, bool PreserveLCSSA, + AssumptionCache &AC, OptimizationRemarkEmitter &ORE, + bool PreserveLCSSA, Optional<unsigned> ProvidedCount, Optional<unsigned> ProvidedThreshold, Optional<bool> ProvidedAllowPartial, - Optional<bool> ProvidedRuntime) { + Optional<bool> ProvidedRuntime, + Optional<bool> ProvidedUpperBound) { DEBUG(dbgs() << "Loop Unroll: F[" << L->getHeader()->getParent()->getName() << "] Loop %" << L->getHeader()->getName() << "\n"); - if (HasUnrollDisablePragma(L)) { + if (HasUnrollDisablePragma(L)) + return false; + if (!L->isLoopSimplifyForm()) { + DEBUG( + dbgs() << " Not unrolling loop which is not in loop-simplify form.\n"); return false; } unsigned NumInlineCandidates; bool NotDuplicatable; bool Convergent; + TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( + L, TTI, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, + ProvidedRuntime, ProvidedUpperBound); + // Exit early if unrolling is disabled. 
+ if (UP.Threshold == 0 && (!UP.Partial || UP.PartialThreshold == 0)) + return false; unsigned LoopSize = ApproximateLoopSize( - L, NumInlineCandidates, NotDuplicatable, Convergent, TTI, &AC); + L, NumInlineCandidates, NotDuplicatable, Convergent, TTI, &AC, UP.BEInsns); DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); if (NotDuplicatable) { DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" @@ -922,14 +951,10 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); return false; } - if (!L->isLoopSimplifyForm()) { - DEBUG( - dbgs() << " Not unrolling loop which is not in loop-simplify form.\n"); - return false; - } // Find trip count and trip multiple if count is not available unsigned TripCount = 0; + unsigned MaxTripCount = 0; unsigned TripMultiple = 1; // If there are multiple exiting blocks but one of them is the latch, use the // latch for the trip count estimation. Otherwise insist on a single exiting @@ -942,10 +967,6 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); } - TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( - L, TTI, ProvidedThreshold, ProvidedCount, ProvidedAllowPartial, - ProvidedRuntime); - // If the loop contains a convergent operation, the prelude we'd add // to do the first few instructions before we hit the unrolled loop // is unsafe -- it adds a control-flow dependency to the convergent @@ -961,8 +982,31 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, if (Convergent) UP.AllowRemainder = false; - bool IsCountSetExplicitly = computeUnrollCount(L, TTI, DT, LI, SE, TripCount, - TripMultiple, LoopSize, UP); + // Try to find the trip count upper bound if we cannot find the exact trip + // count. + bool MaxOrZero = false; + if (!TripCount) { + MaxTripCount = SE->getSmallConstantMaxTripCount(L); + MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L); + // We can unroll by the upper bound amount if it's generally allowed or if + // we know that the loop is executed either the upper bound or zero times. + // (MaxOrZero unrolling keeps only the first loop test, so the number of + // loop tests remains the same compared to the non-unrolled version, whereas + // the generic upper bound unrolling keeps all but the last loop test so the + // number of loop tests goes up which may end up being worse on targets with + // constriained branch predictor resources so is controlled by an option.) + // In addition we only unroll small upper bounds. + if (!(UP.UpperBound || MaxOrZero) || MaxTripCount > UnrollMaxUpperBound) { + MaxTripCount = 0; + } + } + + // computeUnrollCount() decides whether it is beneficial to use upper bound to + // fully unroll the loop. + bool UseUpperBound = false; + bool IsCountSetExplicitly = + computeUnrollCount(L, TTI, DT, LI, SE, &ORE, TripCount, MaxTripCount, + TripMultiple, LoopSize, UP, UseUpperBound); if (!UP.Count) return false; // Unroll factor (Count) must be less or equal to TripCount. @@ -971,14 +1015,18 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, // Unroll the loop. 
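//===--- Editorial sketch (not part of the patch) -------------------------===//
// The upper-bound gating decision described above, reduced to a predicate.
// Parameter names are local to this sketch; in the patch the inputs come from
// ScalarEvolution (getSmallConstantMaxTripCount, isBackedgeTakenCountMaxOrZero)
// and the -unroll-max-upperbound option.
static unsigned usableMaxTripCount(unsigned SCEVMaxTripCount, bool MaxOrZero,
                                   bool UpperBoundAllowed,
                                   unsigned MaxUpperBound) {
  // Upper-bound unrolling is only considered when it is generally enabled, or
  // when the loop provably runs either exactly the bound or zero times, and
  // only for small bounds, since the extra loop tests can pressure the branch
  // predictor.
  if (!(UpperBoundAllowed || MaxOrZero) || SCEVMaxTripCount > MaxUpperBound)
    return 0; // fall back to treating the trip count as unknown
  return SCEVMaxTripCount;
}
//===----------------------------------------------------------------------===//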
if (!UnrollLoop(L, UP.Count, TripCount, UP.Force, UP.Runtime, - UP.AllowExpensiveTripCount, TripMultiple, LI, SE, &DT, &AC, + UP.AllowExpensiveTripCount, UseUpperBound, MaxOrZero, + TripMultiple, UP.PeelCount, LI, SE, &DT, &AC, &ORE, PreserveLCSSA)) return false; // If loop has an unroll count pragma or unrolled by explicitly set count // mark loop as unrolled to prevent unrolling beyond that requested. - if (IsCountSetExplicitly) + // If the loop was peeled, we already "used up" the profile information + // we had, so we don't want to unroll or peel again. + if (IsCountSetExplicitly || UP.PeelCount) SetLoopAlreadyUnrolled(L); + return true; } @@ -988,10 +1036,11 @@ public: static char ID; // Pass ID, replacement for typeid LoopUnroll(Optional<unsigned> Threshold = None, Optional<unsigned> Count = None, - Optional<bool> AllowPartial = None, Optional<bool> Runtime = None) + Optional<bool> AllowPartial = None, Optional<bool> Runtime = None, + Optional<bool> UpperBound = None) : LoopPass(ID), ProvidedCount(std::move(Count)), ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), - ProvidedRuntime(Runtime) { + ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound) { initializeLoopUnrollPass(*PassRegistry::getPassRegistry()); } @@ -999,6 +1048,7 @@ public: Optional<unsigned> ProvidedThreshold; Optional<bool> ProvidedAllowPartial; Optional<bool> ProvidedRuntime; + Optional<bool> ProvidedUpperBound; bool runOnLoop(Loop *L, LPPassManager &) override { if (skipLoop(L)) @@ -1012,11 +1062,16 @@ public: const TargetTransformInfo &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + // For the old PM, we can't use OptimizationRemarkEmitter as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). + OptimizationRemarkEmitter ORE(&F); bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); - return tryToUnrollLoop(L, DT, LI, SE, TTI, AC, PreserveLCSSA, ProvidedCount, - ProvidedThreshold, ProvidedAllowPartial, - ProvidedRuntime); + return tryToUnrollLoop(L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, + ProvidedCount, ProvidedThreshold, + ProvidedAllowPartial, ProvidedRuntime, + ProvidedUpperBound); } /// This transformation requires natural loop information & requires that @@ -1040,7 +1095,7 @@ INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(LoopUnroll, "loop-unroll", "Unroll loops", false, false) Pass *llvm::createLoopUnrollPass(int Threshold, int Count, int AllowPartial, - int Runtime) { + int Runtime, int UpperBound) { // TODO: It would make more sense for this function to take the optionals // directly, but that's dangerous since it would silently break out of tree // callers. @@ -1048,9 +1103,33 @@ Pass *llvm::createLoopUnrollPass(int Threshold, int Count, int AllowPartial, Count == -1 ? None : Optional<unsigned>(Count), AllowPartial == -1 ? None : Optional<bool>(AllowPartial), - Runtime == -1 ? None : Optional<bool>(Runtime)); + Runtime == -1 ? None : Optional<bool>(Runtime), + UpperBound == -1 ? 
None : Optional<bool>(UpperBound)); } Pass *llvm::createSimpleLoopUnrollPass() { - return llvm::createLoopUnrollPass(-1, -1, 0, 0); + return llvm::createLoopUnrollPass(-1, -1, 0, 0, 0); +} + +PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + Function *F = L.getHeader()->getParent(); + + auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); + // FIXME: This should probably be optional rather than required. + if (!ORE) + report_fatal_error("LoopUnrollPass: OptimizationRemarkEmitterAnalysis not " + "cached at a higher level"); + + bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, &AR.SE, AR.TTI, AR.AC, *ORE, + /*PreserveLCSSA*/ true, ProvidedCount, + ProvidedThreshold, ProvidedAllowPartial, + ProvidedRuntime, ProvidedUpperBound); + + if (!Changed) + return PreservedAnalyses::all(); + return getLoopPassPreservedAnalyses(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp index 71980e8..76fe918 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -210,7 +210,7 @@ namespace { bool runOnLoop(Loop *L, LPPassManager &LPM) override; bool processCurrentLoop(); - + bool isUnreachableDueToPreviousUnswitching(BasicBlock *); /// This transformation requires natural loop information & requires that /// loop preheaders be inserted into the CFG. /// @@ -483,6 +483,35 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { return Changed; } +// Return true if the BasicBlock BB is unreachable from the loop header. +// Return false, otherwise. +bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) { + auto *Node = DT->getNode(BB)->getIDom(); + BasicBlock *DomBB = Node->getBlock(); + while (currentLoop->contains(DomBB)) { + BranchInst *BInst = dyn_cast<BranchInst>(DomBB->getTerminator()); + + Node = DT->getNode(DomBB)->getIDom(); + DomBB = Node->getBlock(); + + if (!BInst || !BInst->isConditional()) + continue; + + Value *Cond = BInst->getCondition(); + if (!isa<ConstantInt>(Cond)) + continue; + + BasicBlock *UnreachableSucc = + Cond == ConstantInt::getTrue(Cond->getContext()) + ? BInst->getSuccessor(1) + : BInst->getSuccessor(0); + + if (DT->dominates(UnreachableSucc, BB)) + return true; + } + return false; +} + /// Do actual work and unswitch loop if possible and profitable. bool LoopUnswitch::processCurrentLoop() { bool Changed = false; @@ -593,6 +622,12 @@ bool LoopUnswitch::processCurrentLoop() { continue; if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + // Some branches may be rendered unreachable because of previous + // unswitching. + // Unswitch only those branches that are reachable. + if (isUnreachableDueToPreviousUnswitching(*I)) + continue; + // If this isn't branching on an invariant condition, we can't unswitch // it. 
if (BI->isConditional()) { @@ -742,42 +777,6 @@ static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, return &New; } -static void copyMetadata(Instruction *DstInst, const Instruction *SrcInst, - bool Swapped) { - if (!SrcInst || !SrcInst->hasMetadata()) - return; - - SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; - SrcInst->getAllMetadata(MDs); - for (auto &MD : MDs) { - switch (MD.first) { - default: - break; - case LLVMContext::MD_prof: - if (Swapped && MD.second->getNumOperands() == 3 && - isa<MDString>(MD.second->getOperand(0))) { - MDString *MDName = cast<MDString>(MD.second->getOperand(0)); - if (MDName->getString() == "branch_weights") { - auto *ValT = cast_or_null<ConstantAsMetadata>( - MD.second->getOperand(1))->getValue(); - auto *ValF = cast_or_null<ConstantAsMetadata>( - MD.second->getOperand(2))->getValue(); - assert(ValT && ValF && "Invalid Operands of branch_weights"); - auto NewMD = - MDBuilder(DstInst->getParent()->getContext()) - .createBranchWeights(cast<ConstantInt>(ValF)->getZExtValue(), - cast<ConstantInt>(ValT)->getZExtValue()); - MD.second = NewMD; - } - } - // fallthrough. - case LLVMContext::MD_make_implicit: - case LLVMContext::MD_dbg: - DstInst->setMetadata(MD.first, MD.second); - } - } -} - /// Emit a conditional branch on two values if LIC == Val, branch to TrueDst, /// otherwise branch to FalseDest. Insert the code immediately before InsertPt. void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, @@ -799,8 +798,10 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, } // Insert the new branch. - BranchInst *BI = BranchInst::Create(TrueDest, FalseDest, BranchVal, InsertPt); - copyMetadata(BI, TI, Swapped); + BranchInst *BI = + IRBuilder<>(InsertPt).CreateCondBr(BranchVal, TrueDest, FalseDest, TI); + if (Swapped) + BI->swapProfMetadata(); // If either edge is critical, split it. This helps preserve LoopSimplify // form for enclosing loops. @@ -1078,10 +1079,6 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, F->getBasicBlockList(), NewBlocks[0]->getIterator(), F->end()); - // FIXME: We could register any cloned assumptions instead of clearing the - // whole function's cache. - AC->clear(); - // Now we create the new Loop object for the versioned loop. Loop *NewLoop = CloneLoop(L, L->getParentLoop(), VMap, LI, LPM); @@ -1131,10 +1128,15 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, } // Rewrite the code to refer to itself. - for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) - for (Instruction &I : *NewBlocks[i]) + for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) { + for (Instruction &I : *NewBlocks[i]) { RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC->registerAssumption(II); + } + } // Rewrite the original preheader to select between versions of the loop. BranchInst *OldBR = cast<BranchInst>(loopPreheader->getTerminator()); @@ -1380,8 +1382,8 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { Pred->getInstList().splice(BI->getIterator(), Succ->getInstList(), Succ->begin(), Succ->end()); LPM->deleteSimpleAnalysisValue(BI, L); - BI->eraseFromParent(); RemoveFromWorklist(BI, Worklist); + BI->eraseFromParent(); // Remove Succ from the loop tree. 
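//===--- Editorial sketch (not part of the patch) -------------------------===//
// Why the hand-rolled copyMetadata above could be dropped: the IRBuilder
// overload of CreateCondBr that takes an Instruction* copies the relevant
// metadata (!prof, !make.implicit, !dbg) from that instruction, and when the
// new branch tests the same condition with its successors exchanged, only the
// branch_weights need to be exchanged too, which swapProfMetadata does.  A
// minimal sketch mirroring EmitPreheaderBranchOnCondition:
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

static BranchInst *emitBranchLikeUnswitch(Value *Cond, BasicBlock *TrueDest,
                                          BasicBlock *FalseDest,
                                          Instruction *InsertPt,
                                          Instruction *MDFrom, bool Swapped) {
  BranchInst *BI =
      IRBuilder<>(InsertPt).CreateCondBr(Cond, TrueDest, FalseDest, MDFrom);
  if (Swapped)
    BI->swapProfMetadata(); // keep the weights attached to the right successor
  return BI;
}
//===----------------------------------------------------------------------===//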
LI->removeBlock(Succ); diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 0ccf0af..c23d891 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -92,8 +92,7 @@ #include "llvm/Transforms/Utils/ValueMapper.h" #define DEBUG_TYPE "loop-versioning-licm" -static const char* LICMVersioningMetaData = - "llvm.loop.licm_versioning.disable"; +static const char *LICMVersioningMetaData = "llvm.loop.licm_versioning.disable"; using namespace llvm; @@ -158,34 +157,48 @@ struct LoopVersioningLICM : public LoopPass { AU.addRequired<LoopInfoWrapperPass>(); AU.addRequiredID(LoopSimplifyID); AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addPreserved<AAResultsWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } LoopVersioningLICM() - : LoopPass(ID), AA(nullptr), SE(nullptr), LI(nullptr), DT(nullptr), - TLI(nullptr), LAA(nullptr), LAI(nullptr), Changed(false), - Preheader(nullptr), CurLoop(nullptr), CurAST(nullptr), - LoopDepthThreshold(LVLoopDepthThreshold), + : LoopPass(ID), AA(nullptr), SE(nullptr), LAA(nullptr), LAI(nullptr), + CurLoop(nullptr), LoopDepthThreshold(LVLoopDepthThreshold), InvariantThreshold(LVInvarThreshold), LoadAndStoreCounter(0), InvariantCounter(0), IsReadOnlyLoop(true) { initializeLoopVersioningLICMPass(*PassRegistry::getPassRegistry()); } + StringRef getPassName() const override { return "Loop Versioning for LICM"; } - AliasAnalysis *AA; // Current AliasAnalysis information - ScalarEvolution *SE; // Current ScalarEvolution - LoopInfo *LI; // Current LoopInfo - DominatorTree *DT; // Dominator Tree for the current Loop. - TargetLibraryInfo *TLI; // TargetLibraryInfo for constant folding. - LoopAccessLegacyAnalysis *LAA; // Current LoopAccessAnalysis - const LoopAccessInfo *LAI; // Current Loop's LoopAccessInfo + void reset() { + AA = nullptr; + SE = nullptr; + LAA = nullptr; + CurLoop = nullptr; + LoadAndStoreCounter = 0; + InvariantCounter = 0; + IsReadOnlyLoop = true; + CurAST.reset(); + } + + class AutoResetter { + public: + AutoResetter(LoopVersioningLICM &LVLICM) : LVLICM(LVLICM) {} + ~AutoResetter() { LVLICM.reset(); } + + private: + LoopVersioningLICM &LVLICM; + }; - bool Changed; // Set to true when we change anything. - BasicBlock *Preheader; // The preheader block of the current loop. - Loop *CurLoop; // The current loop we are working on. - AliasSetTracker *CurAST; // AliasSet information for the current loop. - ValueToValueMap Strides; +private: + AliasAnalysis *AA; // Current AliasAnalysis information + ScalarEvolution *SE; // Current ScalarEvolution + LoopAccessLegacyAnalysis *LAA; // Current LoopAccessAnalysis + const LoopAccessInfo *LAI; // Current Loop's LoopAccessInfo + + Loop *CurLoop; // The current loop we are working on. + std::unique_ptr<AliasSetTracker> + CurAST; // AliasSet information for the current loop. unsigned LoopDepthThreshold; // Maximum loop nest threshold float InvariantThreshold; // Minimum invariant threshold @@ -200,15 +213,15 @@ struct LoopVersioningLICM : public LoopPass { bool isLoopAlreadyVisited(); void setNoAliasToLoop(Loop *); bool instructionSafeForVersioning(Instruction *); - const char *getPassName() const override { return "Loop Versioning"; } }; } /// \brief Check loop structure and confirms it's good for LoopVersioningLICM. 
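//===--- Editorial sketch (not part of the patch) -------------------------===//
// The AutoResetter introduced above is a small RAII guard: constructing one at
// the top of runOnLoop guarantees reset() runs on every exit path, so the
// per-loop state (including the unique_ptr-held AliasSetTracker) no longer
// needs a manual delete at the end.  A generic sketch of the same idiom, with
// names local to this sketch:
#include <functional>
#include <utility>

class ScopeExit {
public:
  explicit ScopeExit(std::function<void()> Fn) : Fn(std::move(Fn)) {}
  ~ScopeExit() { Fn(); } // runs the cleanup on scope exit, even on early return
  ScopeExit(const ScopeExit &) = delete;
  ScopeExit &operator=(const ScopeExit &) = delete;

private:
  std::function<void()> Fn;
};

// Usage, analogous to `AutoResetter Resetter(*this);` in runOnLoop:
//   ScopeExit Cleanup([&] { reset(); });
//===----------------------------------------------------------------------===//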
bool LoopVersioningLICM::legalLoopStructure() { - // Loop must have a preheader, if not return false. - if (!CurLoop->getLoopPreheader()) { - DEBUG(dbgs() << " loop preheader is missing\n"); + // Loop must be in loop simplify form. + if (!CurLoop->isLoopSimplifyForm()) { + DEBUG( + dbgs() << " loop is not in loop-simplify form.\n"); return false; } // Loop should be innermost loop, if not return false. @@ -244,11 +257,6 @@ bool LoopVersioningLICM::legalLoopStructure() { DEBUG(dbgs() << " loop depth is more then threshold\n"); return false; } - // Loop should have a dedicated exit block, if not return false. - if (!CurLoop->hasDedicatedExits()) { - DEBUG(dbgs() << " loop does not has dedicated exit blocks\n"); - return false; - } // We need to be able to compute the loop trip count in order // to generate the bound checks. const SCEV *ExitCount = SE->getBackedgeTakenCount(CurLoop); @@ -505,29 +513,30 @@ void LoopVersioningLICM::setNoAliasToLoop(Loop *VerLoop) { } bool LoopVersioningLICM::runOnLoop(Loop *L, LPPassManager &LPM) { + // This will automatically release all resources hold by the current + // LoopVersioningLICM object. + AutoResetter Resetter(*this); + if (skipLoop(L)) return false; - Changed = false; // Get Analysis information. - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); LAI = nullptr; // Set Current Loop CurLoop = L; - // Get the preheader block. - Preheader = L->getLoopPreheader(); - // Initial allocation - CurAST = new AliasSetTracker(*AA); + CurAST.reset(new AliasSetTracker(*AA)); // Loop over the body of this loop, construct AST. + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); for (auto *Block : L->getBlocks()) { if (LI->getLoopFor(Block) == L) // Ignore blocks in subloop. CurAST->add(*Block); // Incorporate the specified basic block } + + bool Changed = false; + // Check feasiblity of LoopVersioningLICM. // If versioning found to be feasible and beneficial then proceed // else simply return, by cleaning up memory. @@ -535,6 +544,7 @@ bool LoopVersioningLICM::runOnLoop(Loop *L, LPPassManager &LPM) { // Do loop versioning. // Create memcheck for memory accessed inside loop. // Clone original loop, and set blocks properly. + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); LoopVersioning LVer(*LAI, CurLoop, LI, DT, SE, true); LVer.versionLoop(); // Set Loop Versioning metaData for original loop. @@ -548,8 +558,6 @@ bool LoopVersioningLICM::runOnLoop(Loop *L, LPPassManager &LPM) { setNoAliasToLoop(LVer.getVersionedLoop()); Changed = true; } - // Delete allocated memory. 
- delete CurAST; return Changed; } @@ -564,7 +572,6 @@ INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(LoopVersioningLICM, "loop-versioning-licm", "Loop Versioning For LICM", false, false) diff --git a/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 79f0db1..52975ef 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -83,9 +83,8 @@ static bool handleSwitchExpect(SwitchInst &SI) { return true; } -static bool handleBranchExpect(BranchInst &BI) { - if (BI.isUnconditional()) - return false; +// Handle both BranchInst and SelectInst. +template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) { // Handle non-optimized IR code like: // %expval = call i64 @llvm.expect.i64(i64 %conv1, i64 1) @@ -98,9 +97,9 @@ static bool handleBranchExpect(BranchInst &BI) { CallInst *CI; - ICmpInst *CmpI = dyn_cast<ICmpInst>(BI.getCondition()); + ICmpInst *CmpI = dyn_cast<ICmpInst>(BSI.getCondition()); if (!CmpI) { - CI = dyn_cast<CallInst>(BI.getCondition()); + CI = dyn_cast<CallInst>(BSI.getCondition()); } else { if (CmpI->getPredicate() != CmpInst::ICMP_NE) return false; @@ -129,15 +128,22 @@ static bool handleBranchExpect(BranchInst &BI) { else Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight); - BI.setMetadata(LLVMContext::MD_prof, Node); + BSI.setMetadata(LLVMContext::MD_prof, Node); if (CmpI) CmpI->setOperand(0, ArgValue); else - BI.setCondition(ArgValue); + BSI.setCondition(ArgValue); return true; } +static bool handleBranchExpect(BranchInst &BI) { + if (BI.isUnconditional()) + return false; + + return handleBrSelExpect<BranchInst>(BI); +} + static bool lowerExpectIntrinsic(Function &F) { bool Changed = false; @@ -151,11 +157,19 @@ static bool lowerExpectIntrinsic(Function &F) { ExpectIntrinsicsHandled++; } - // Remove llvm.expect intrinsics. - for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) { - CallInst *CI = dyn_cast<CallInst>(BI++); - if (!CI) + // Remove llvm.expect intrinsics. Iterate backwards in order + // to process select instructions before the intrinsic gets + // removed. 
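//===--- Editorial aside (not part of the patch) --------------------------===//
// A source-level illustration of what the templated handleBrSelExpect above
// adds: __builtin_expect feeding a conditional expression that the frontend
// may lower to a select rather than a branch.  With this change the select
// also receives !prof branch weights; previously only the branch form was
// annotated.  (Whether a select or a branch is emitted depends on the
// frontend and optimization level; that is an assumption for illustration.)
static int clampToZero(int X) {
  return __builtin_expect(X > 0, 1) ? X : 0;
}
//===----------------------------------------------------------------------===//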
+ for (auto BI = BB.rbegin(), BE = BB.rend(); BI != BE;) { + Instruction *Inst = &*BI++; + CallInst *CI = dyn_cast<CallInst>(Inst); + if (!CI) { + if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) { + if (handleBrSelExpect(*SI)) + ExpectIntrinsicsHandled++; + } continue; + } Function *Fn = CI->getCalledFunction(); if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) { diff --git a/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp b/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp index 5749100..4f41371 100644 --- a/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp @@ -13,7 +13,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LowerGuardIntrinsic.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -24,6 +24,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -34,10 +35,11 @@ static cl::opt<uint32_t> PredicatePassBranchWeight( "reciprocal of this value (default = 1 << 20)")); namespace { -struct LowerGuardIntrinsic : public FunctionPass { +struct LowerGuardIntrinsicLegacyPass : public FunctionPass { static char ID; - LowerGuardIntrinsic() : FunctionPass(ID) { - initializeLowerGuardIntrinsicPass(*PassRegistry::getPassRegistry()); + LowerGuardIntrinsicLegacyPass() : FunctionPass(ID) { + initializeLowerGuardIntrinsicLegacyPassPass( + *PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; @@ -83,7 +85,7 @@ static void MakeGuardControlFlowExplicit(Function *DeoptIntrinsic, DeoptBlockTerm->eraseFromParent(); } -bool LowerGuardIntrinsic::runOnFunction(Function &F) { +static bool lowerGuardIntrinsic(Function &F) { // Check if we can cheaply rule out the possibility of not having any work to // do. auto *GuardDecl = F.getParent()->getFunction( @@ -113,11 +115,23 @@ bool LowerGuardIntrinsic::runOnFunction(Function &F) { return true; } -char LowerGuardIntrinsic::ID = 0; -INITIALIZE_PASS(LowerGuardIntrinsic, "lower-guard-intrinsic", +bool LowerGuardIntrinsicLegacyPass::runOnFunction(Function &F) { + return lowerGuardIntrinsic(F); +} + +char LowerGuardIntrinsicLegacyPass::ID = 0; +INITIALIZE_PASS(LowerGuardIntrinsicLegacyPass, "lower-guard-intrinsic", "Lower the guard intrinsic to normal control flow", false, false) Pass *llvm::createLowerGuardIntrinsicPass() { - return new LowerGuardIntrinsic(); + return new LowerGuardIntrinsicLegacyPass(); +} + +PreservedAnalyses LowerGuardIntrinsicPass::run(Function &F, + FunctionAnalysisManager &AM) { + if (lowerGuardIntrinsic(F)) + return PreservedAnalyses::none(); + + return PreservedAnalyses::all(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp index d64c658..1b59014 100644 --- a/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -52,7 +52,7 @@ static int64_t GetOffsetFromIndex(const GEPOperator *GEP, unsigned Idx, if (OpC->isZero()) continue; // No offset. // Handle struct indices, which add their field offset to the pointer. 
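//===--- Editorial sketch (not part of the patch) -------------------------===//
// The gep_type_iterator accessors this import moves to, used in the change
// just below and again in NaryReassociate further down: instead of
// dyn_cast'ing the dereferenced iterator, the iterator itself reports whether
// the current GEP index steps into a struct or a sequential type.  Minimal
// usage sketch; the helper name is local to this sketch.
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

static unsigned countStructIndices(GetElementPtrInst *GEP) {
  unsigned N = 0;
  for (gep_type_iterator GTI = gep_type_begin(*GEP), E = gep_type_end(*GEP);
       GTI != E; ++GTI)
    if (GTI.getStructTypeOrNull()) // non-null only when indexing into a struct
      ++N;
  return N;
}
//===----------------------------------------------------------------------===//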
- if (StructType *STy = dyn_cast<StructType>(*GTI)) { + if (StructType *STy = GTI.getStructTypeOrNull()) { Offset += DL.getStructLayout(STy)->getElementOffset(OpC->getZExtValue()); continue; } @@ -489,7 +489,8 @@ static unsigned findCommonAlignment(const DataLayout &DL, const StoreInst *SI, // It will lift the store and its argument + that anything that // may alias with these. // The method returns true if it was successful. -static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P) { +static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, + const LoadInst *LI) { // If the store alias this position, early bail out. MemoryLocation StoreLoc = MemoryLocation::get(SI); if (AA.getModRefInfo(P, StoreLoc) != MRI_NoModRef) @@ -506,12 +507,13 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P) { SmallVector<Instruction*, 8> ToLift; // Memory locations of lifted instructions. - SmallVector<MemoryLocation, 8> MemLocs; - MemLocs.push_back(StoreLoc); + SmallVector<MemoryLocation, 8> MemLocs{StoreLoc}; // Lifted callsites. SmallVector<ImmutableCallSite, 8> CallSites; + const MemoryLocation LoadLoc = MemoryLocation::get(LI); + for (auto I = --SI->getIterator(), E = P->getIterator(); I != E; --I) { auto *C = &*I; @@ -521,23 +523,25 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P) { if (Args.erase(C)) NeedLift = true; else if (MayAlias) { - NeedLift = std::any_of(MemLocs.begin(), MemLocs.end(), - [C, &AA](const MemoryLocation &ML) { - return AA.getModRefInfo(C, ML); - }); + NeedLift = any_of(MemLocs, [C, &AA](const MemoryLocation &ML) { + return AA.getModRefInfo(C, ML); + }); if (!NeedLift) - NeedLift = std::any_of(CallSites.begin(), CallSites.end(), - [C, &AA](const ImmutableCallSite &CS) { - return AA.getModRefInfo(C, CS); - }); + NeedLift = any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) { + return AA.getModRefInfo(C, CS); + }); } if (!NeedLift) continue; if (MayAlias) { - if (auto CS = ImmutableCallSite(C)) { + // Since LI is implicitly moved downwards past the lifted instructions, + // none of them may modify its source. + if (AA.getModRefInfo(C, LoadLoc) & MRI_Mod) + return false; + else if (auto CS = ImmutableCallSite(C)) { // If we can't lift this before P, it's game over. if (AA.getModRefInfo(P, CS) != MRI_NoModRef) return false; @@ -612,7 +616,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { // position if nothing alias the store memory after this and the store // destination is not in the range. if (P && P != SI) { - if (!moveUp(AA, SI, P)) + if (!moveUp(AA, SI, P, LI)) P = nullptr; } @@ -1082,10 +1086,10 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, DestSize = Builder.CreateZExt(DestSize, SrcSize->getType()); } - Value *MemsetLen = - Builder.CreateSelect(Builder.CreateICmpULE(DestSize, SrcSize), - ConstantInt::getNullValue(DestSize->getType()), - Builder.CreateSub(DestSize, SrcSize)); + Value *Ule = Builder.CreateICmpULE(DestSize, SrcSize); + Value *SizeDiff = Builder.CreateSub(DestSize, SrcSize); + Value *MemsetLen = Builder.CreateSelect( + Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff); Builder.CreateMemSet(Builder.CreateGEP(Dest, SrcSize), MemSet->getOperand(1), MemsetLen, Align); @@ -1110,8 +1114,11 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, /// The \p MemCpy must have a Constant length. 
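//===--- Editorial sketch (not part of the patch) -------------------------===//
// The tail-memset length computed by processMemSetMemCpyDependence earlier in
// this hunk, written as plain arithmetic (the patch only restructures how the
// select operands are built).  The helper name is local to this sketch.
#include <cstdint>

static uint64_t tailMemsetLen(uint64_t MemsetSize, uint64_t MemcpySize) {
  // A memcpy that overwrites the start of a memset'ed region leaves only the
  // uncovered tail to be re-memset; if it covers the whole region, nothing is
  // left to do.
  return MemsetSize <= MemcpySize ? 0 : MemsetSize - MemcpySize;
}

// E.g. memset(dst, c, 64) followed by memcpy(dst, src, 16) keeps a 48-byte
// memset of dst+16: tailMemsetLen(64, 16) == 48.
//===----------------------------------------------------------------------===//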
bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, MemSetInst *MemSet) { - // This only makes sense on memcpy(..., memset(...), ...). - if (MemSet->getRawDest() != MemCpy->getRawSource()) + AliasAnalysis &AA = LookupAliasAnalysis(); + + // Make sure that memcpy(..., memset(...), ...), that is we are memsetting and + // memcpying from the same address. Otherwise it is hard to reason about. + if (!AA.isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource())) return false; ConstantInt *CopySize = cast<ConstantInt>(MemCpy->getLength()); diff --git a/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 30261b7..6a64c6b 100644 --- a/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -260,7 +260,7 @@ void MergedLoadStoreMotion::hoistInstruction(BasicBlock *BB, assert(HoistCand->getParent() != BB); // Intersect optional metadata. - HoistCand->intersectOptionalDataWith(ElseInst); + HoistCand->andIRFlags(ElseInst); HoistCand->dropUnknownNonDebugMetadata(); // Prepend point for instruction insert @@ -434,7 +434,7 @@ bool MergedLoadStoreMotion::sinkStore(BasicBlock *BB, StoreInst *S0, // Hoist the instruction. BasicBlock::iterator InsertPt = BB->getFirstInsertionPt(); // Intersect optional metadata. - S0->intersectOptionalDataWith(S1); + S0->andIRFlags(S1); S0->dropUnknownNonDebugMetadata(); // Create the new store to be inserted at the join point. @@ -563,7 +563,6 @@ public: } private: - // This transformation requires dominator postdominator info void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired<AAResultsWrapperPass>(); @@ -590,7 +589,7 @@ INITIALIZE_PASS_END(MergedLoadStoreMotionLegacyPass, "mldst-motion", "MergedLoadStoreMotion", false, false) PreservedAnalyses -MergedLoadStoreMotionPass::run(Function &F, AnalysisManager<Function> &AM) { +MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) { MergedLoadStoreMotion Impl; auto *MD = AM.getCachedResult<MemoryDependenceAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); diff --git a/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index ed754fa..0a3bf7b 100644 --- a/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -76,12 +76,8 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Transforms/Scalar/NaryReassociate.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Dominators.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" @@ -94,16 +90,15 @@ using namespace PatternMatch; #define DEBUG_TYPE "nary-reassociate" namespace { -class NaryReassociate : public FunctionPass { +class NaryReassociateLegacyPass : public FunctionPass { public: static char ID; - NaryReassociate(): FunctionPass(ID) { - initializeNaryReassociatePass(*PassRegistry::getPassRegistry()); + NaryReassociateLegacyPass() : FunctionPass(ID) { + initializeNaryReassociateLegacyPassPass(*PassRegistry::getPassRegistry()); } bool doInitialization(Module &M) override { - DL = &M.getDataLayout(); return false; } bool runOnFunction(Function &F) 
override; @@ -121,101 +116,73 @@ public: } private: - // Runs only one iteration of the dominator-based algorithm. See the header - // comments for why we need multiple iterations. - bool doOneIteration(Function &F); - - // Reassociates I for better CSE. - Instruction *tryReassociate(Instruction *I); - - // Reassociate GEP for better CSE. - Instruction *tryReassociateGEP(GetElementPtrInst *GEP); - // Try splitting GEP at the I-th index and see whether either part can be - // CSE'ed. This is a helper function for tryReassociateGEP. - // - // \p IndexedType The element type indexed by GEP's I-th index. This is - // equivalent to - // GEP->getIndexedType(GEP->getPointerOperand(), 0-th index, - // ..., i-th index). - GetElementPtrInst *tryReassociateGEPAtIndex(GetElementPtrInst *GEP, - unsigned I, Type *IndexedType); - // Given GEP's I-th index = LHS + RHS, see whether &Base[..][LHS][..] or - // &Base[..][RHS][..] can be CSE'ed and rewrite GEP accordingly. - GetElementPtrInst *tryReassociateGEPAtIndex(GetElementPtrInst *GEP, - unsigned I, Value *LHS, - Value *RHS, Type *IndexedType); - - // Reassociate binary operators for better CSE. - Instruction *tryReassociateBinaryOp(BinaryOperator *I); - - // A helper function for tryReassociateBinaryOp. LHS and RHS are explicitly - // passed. - Instruction *tryReassociateBinaryOp(Value *LHS, Value *RHS, - BinaryOperator *I); - // Rewrites I to (LHS op RHS) if LHS is computed already. - Instruction *tryReassociatedBinaryOp(const SCEV *LHS, Value *RHS, - BinaryOperator *I); - - // Tries to match Op1 and Op2 by using V. - bool matchTernaryOp(BinaryOperator *I, Value *V, Value *&Op1, Value *&Op2); - - // Gets SCEV for (LHS op RHS). - const SCEV *getBinarySCEV(BinaryOperator *I, const SCEV *LHS, - const SCEV *RHS); - - // Returns the closest dominator of \c Dominatee that computes - // \c CandidateExpr. Returns null if not found. - Instruction *findClosestMatchingDominator(const SCEV *CandidateExpr, - Instruction *Dominatee); - // GetElementPtrInst implicitly sign-extends an index if the index is shorter - // than the pointer size. This function returns whether Index is shorter than - // GEP's pointer size, i.e., whether Index needs to be sign-extended in order - // to be an index of GEP. - bool requiresSignExtension(Value *Index, GetElementPtrInst *GEP); - - AssumptionCache *AC; - const DataLayout *DL; - DominatorTree *DT; - ScalarEvolution *SE; - TargetLibraryInfo *TLI; - TargetTransformInfo *TTI; - // A lookup table quickly telling which instructions compute the given SCEV. - // Note that there can be multiple instructions at different locations - // computing to the same SCEV, so we map a SCEV to an instruction list. 
For - // example, - // - // if (p1) - // foo(a + b); - // if (p2) - // bar(a + b); - DenseMap<const SCEV *, SmallVector<WeakVH, 2>> SeenExprs; + NaryReassociatePass Impl; }; } // anonymous namespace -char NaryReassociate::ID = 0; -INITIALIZE_PASS_BEGIN(NaryReassociate, "nary-reassociate", "Nary reassociation", - false, false) +char NaryReassociateLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(NaryReassociateLegacyPass, "nary-reassociate", + "Nary reassociation", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(NaryReassociate, "nary-reassociate", "Nary reassociation", - false, false) +INITIALIZE_PASS_END(NaryReassociateLegacyPass, "nary-reassociate", + "Nary reassociation", false, false) FunctionPass *llvm::createNaryReassociatePass() { - return new NaryReassociate(); + return new NaryReassociateLegacyPass(); } -bool NaryReassociate::runOnFunction(Function &F) { +bool NaryReassociateLegacyPass::runOnFunction(Function &F) { if (skipFunction(F)) return false; - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + + return Impl.runImpl(F, AC, DT, SE, TLI, TTI); +} + +PreservedAnalyses NaryReassociatePass::run(Function &F, + FunctionAnalysisManager &AM) { + auto *AC = &AM.getResult<AssumptionAnalysis>(F); + auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); + auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); + auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F); + auto *TTI = &AM.getResult<TargetIRAnalysis>(F); + + bool Changed = runImpl(F, AC, DT, SE, TLI, TTI); + + // FIXME: We need to invalidate this to avoid PR28400. Is there a better + // solution? + AM.invalidate<ScalarEvolutionAnalysis>(F); + + if (!Changed) + return PreservedAnalyses::all(); + + // FIXME: This should also 'preserve the CFG'. + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<ScalarEvolutionAnalysis>(); + PA.preserve<TargetLibraryAnalysis>(); + return PA; +} + +bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_, + DominatorTree *DT_, ScalarEvolution *SE_, + TargetLibraryInfo *TLI_, + TargetTransformInfo *TTI_) { + AC = AC_; + DT = DT_; + SE = SE_; + TLI = TLI_; + TTI = TTI_; + DL = &F.getParent()->getDataLayout(); bool Changed = false, ChangedInThisIteration; do { @@ -237,13 +204,13 @@ static bool isPotentiallyNaryReassociable(Instruction *I) { } } -bool NaryReassociate::doOneIteration(Function &F) { +bool NaryReassociatePass::doOneIteration(Function &F) { bool Changed = false; SeenExprs.clear(); - // Process the basic blocks in pre-order of the dominator tree. This order - // ensures that all bases of a candidate are in Candidates when we process it. 
- for (auto Node = GraphTraits<DominatorTree *>::nodes_begin(DT); - Node != GraphTraits<DominatorTree *>::nodes_end(DT); ++Node) { + // Process the basic blocks in a depth first traversal of the dominator + // tree. This order ensures that all bases of a candidate are in Candidates + // when we process it. + for (const auto Node : depth_first(DT)) { BasicBlock *BB = Node->getBlock(); for (auto I = BB->begin(); I != BB->end(); ++I) { if (SE->isSCEVable(I->getType()) && isPotentiallyNaryReassociable(&*I)) { @@ -287,7 +254,7 @@ bool NaryReassociate::doOneIteration(Function &F) { return Changed; } -Instruction *NaryReassociate::tryReassociate(Instruction *I) { +Instruction *NaryReassociatePass::tryReassociate(Instruction *I) { switch (I->getOpcode()) { case Instruction::Add: case Instruction::Mul: @@ -308,15 +275,16 @@ static bool isGEPFoldable(GetElementPtrInst *GEP, Indices) == TargetTransformInfo::TCC_Free; } -Instruction *NaryReassociate::tryReassociateGEP(GetElementPtrInst *GEP) { +Instruction *NaryReassociatePass::tryReassociateGEP(GetElementPtrInst *GEP) { // Not worth reassociating GEP if it is foldable. if (isGEPFoldable(GEP, TTI)) return nullptr; gep_type_iterator GTI = gep_type_begin(*GEP); - for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I) { - if (isa<SequentialType>(*GTI++)) { - if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I - 1, *GTI)) { + for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { + if (GTI.isSequential()) { + if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I - 1, + GTI.getIndexedType())) { return NewGEP; } } @@ -324,16 +292,16 @@ Instruction *NaryReassociate::tryReassociateGEP(GetElementPtrInst *GEP) { return nullptr; } -bool NaryReassociate::requiresSignExtension(Value *Index, - GetElementPtrInst *GEP) { +bool NaryReassociatePass::requiresSignExtension(Value *Index, + GetElementPtrInst *GEP) { unsigned PointerSizeInBits = DL->getPointerSizeInBits(GEP->getType()->getPointerAddressSpace()); return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits; } GetElementPtrInst * -NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I, - Type *IndexedType) { +NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, + unsigned I, Type *IndexedType) { Value *IndexToSplit = GEP->getOperand(I + 1); if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) { IndexToSplit = SExt->getOperand(0); @@ -366,9 +334,10 @@ NaryReassociate::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, unsigned I, return nullptr; } -GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex( - GetElementPtrInst *GEP, unsigned I, Value *LHS, Value *RHS, - Type *IndexedType) { +GetElementPtrInst * +NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, + unsigned I, Value *LHS, + Value *RHS, Type *IndexedType) { // Look for GEP's closest dominator that has the same SCEV as GEP except that // the I-th index is replaced with LHS. 
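The lookup that both the GEP and binary-operator paths of NaryReassociate rely on is conceptually simple: instructions that compute a given expression are recorded in visitation order, and a query returns the most recently recorded candidate that dominates the current instruction. Below is a minimal standalone sketch of that idea in plain C++, using DFS entry/exit intervals as a stand-in for the dominator-tree query; the names (ToyInst, SeenExprs, the string expression keys) are illustrative only, not the pass's real data structures.

#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Stand-in for an instruction: the DFS interval of its block in the
// dominator tree.  A dominates B iff A's interval encloses B's.
struct ToyInst {
  const char *Name;
  int DFSIn, DFSOut;
};

static bool dominates(const ToyInst &A, const ToyInst &B) {
  return A.DFSIn <= B.DFSIn && B.DFSOut <= A.DFSOut;
}

int main() {
  // "SeenExprs": expression key -> instructions computing it, recorded in the
  // order they were visited (dominator-tree pre-order, so dominators first).
  std::map<std::string, std::vector<ToyInst>> SeenExprs;
  SeenExprs["a+b"] = {{"t1", 1, 10}, {"t2", 12, 20}};

  // Query: is there an existing computation of "a+b" dominating this point?
  ToyInst Query{"t3", 14, 15};

  const ToyInst *Closest = nullptr;
  auto It = SeenExprs.find("a+b");
  if (It != SeenExprs.end()) {
    // Walk from the most recently seen candidate backwards; the first one
    // that dominates the query is the closest dominating computation.
    for (auto RI = It->second.rbegin(), RE = It->second.rend(); RI != RE; ++RI)
      if (dominates(*RI, Query)) { Closest = &*RI; break; }
  }

  std::printf("closest dominating computation: %s\n",
              Closest ? Closest->Name : "<none>");
  return 0;
}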
SmallVector<const SCEV *, 4> IndexExprs; @@ -386,9 +355,8 @@ GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex( IndexExprs[I] = SE->getZeroExtendExpr(IndexExprs[I], GEP->getOperand(I)->getType()); } - const SCEV *CandidateExpr = SE->getGEPExpr( - GEP->getSourceElementType(), SE->getSCEV(GEP->getPointerOperand()), - IndexExprs, GEP->isInBounds()); + const SCEV *CandidateExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), + IndexExprs); Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP); if (Candidate == nullptr) @@ -437,7 +405,7 @@ GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex( return NewGEP; } -Instruction *NaryReassociate::tryReassociateBinaryOp(BinaryOperator *I) { +Instruction *NaryReassociatePass::tryReassociateBinaryOp(BinaryOperator *I) { Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I)) return NewI; @@ -446,8 +414,8 @@ Instruction *NaryReassociate::tryReassociateBinaryOp(BinaryOperator *I) { return nullptr; } -Instruction *NaryReassociate::tryReassociateBinaryOp(Value *LHS, Value *RHS, - BinaryOperator *I) { +Instruction *NaryReassociatePass::tryReassociateBinaryOp(Value *LHS, Value *RHS, + BinaryOperator *I) { Value *A = nullptr, *B = nullptr; // To be conservative, we reassociate I only when it is the only user of (A op // B). @@ -470,9 +438,9 @@ Instruction *NaryReassociate::tryReassociateBinaryOp(Value *LHS, Value *RHS, return nullptr; } -Instruction *NaryReassociate::tryReassociatedBinaryOp(const SCEV *LHSExpr, - Value *RHS, - BinaryOperator *I) { +Instruction *NaryReassociatePass::tryReassociatedBinaryOp(const SCEV *LHSExpr, + Value *RHS, + BinaryOperator *I) { // Look for the closest dominator LHS of I that computes LHSExpr, and replace // I with LHS op RHS. auto *LHS = findClosestMatchingDominator(LHSExpr, I); @@ -494,8 +462,8 @@ Instruction *NaryReassociate::tryReassociatedBinaryOp(const SCEV *LHSExpr, return NewI; } -bool NaryReassociate::matchTernaryOp(BinaryOperator *I, Value *V, Value *&Op1, - Value *&Op2) { +bool NaryReassociatePass::matchTernaryOp(BinaryOperator *I, Value *V, + Value *&Op1, Value *&Op2) { switch (I->getOpcode()) { case Instruction::Add: return match(V, m_Add(m_Value(Op1), m_Value(Op2))); @@ -507,8 +475,9 @@ bool NaryReassociate::matchTernaryOp(BinaryOperator *I, Value *V, Value *&Op1, return false; } -const SCEV *NaryReassociate::getBinarySCEV(BinaryOperator *I, const SCEV *LHS, - const SCEV *RHS) { +const SCEV *NaryReassociatePass::getBinarySCEV(BinaryOperator *I, + const SCEV *LHS, + const SCEV *RHS) { switch (I->getOpcode()) { case Instruction::Add: return SE->getAddExpr(LHS, RHS); @@ -521,8 +490,8 @@ const SCEV *NaryReassociate::getBinarySCEV(BinaryOperator *I, const SCEV *LHS, } Instruction * -NaryReassociate::findClosestMatchingDominator(const SCEV *CandidateExpr, - Instruction *Dominatee) { +NaryReassociatePass::findClosestMatchingDominator(const SCEV *CandidateExpr, + Instruction *Dominatee) { auto Pos = SeenExprs.find(CandidateExpr); if (Pos == SeenExprs.end()) return nullptr; diff --git a/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp b/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp new file mode 100644 index 0000000..57e6e3d --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -0,0 +1,2257 @@ +//===---- NewGVN.cpp - Global Value Numbering Pass --------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +/// \file +/// This file implements the new LLVM's Global Value Numbering pass. +/// GVN partitions values computed by a function into congruence classes. +/// Values ending up in the same congruence class are guaranteed to be the same +/// for every execution of the program. In that respect, congruency is a +/// compile-time approximation of equivalence of values at runtime. +/// The algorithm implemented here uses a sparse formulation and it's based +/// on the ideas described in the paper: +/// "A Sparse Algorithm for Predicated Global Value Numbering" from +/// Karthik Gargi. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/NewGVN.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SparseBitVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/CFGPrinter.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredIteratorCache.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVNExpression.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/MemorySSA.h" +#include "llvm/Transforms/Utils/SSAUpdater.h" +#include <unordered_map> +#include <utility> +#include <vector> +using namespace llvm; +using namespace PatternMatch; +using namespace llvm::GVNExpression; + +#define DEBUG_TYPE "newgvn" + +STATISTIC(NumGVNInstrDeleted, "Number of instructions deleted"); +STATISTIC(NumGVNBlocksDeleted, "Number of blocks deleted"); +STATISTIC(NumGVNOpsSimplified, "Number of Expressions simplified"); +STATISTIC(NumGVNPhisAllSame, "Number of PHIs whos arguments are all the same"); +STATISTIC(NumGVNMaxIterations, + "Maximum Number of iterations it took to converge GVN"); +STATISTIC(NumGVNLeaderChanges, "Number of leader changes"); +STATISTIC(NumGVNSortedLeaderChanges, "Number of sorted leader changes"); +STATISTIC(NumGVNAvoidedSortedLeaderChanges, + "Number of avoided sorted leader changes"); +STATISTIC(NumGVNNotMostDominatingLeader, + "Number of times a member dominated it's new classes' leader"); + +//===----------------------------------------------------------------------===// +// 
GVN Pass +//===----------------------------------------------------------------------===// + +// Anchor methods. +namespace llvm { +namespace GVNExpression { +Expression::~Expression() = default; +BasicExpression::~BasicExpression() = default; +CallExpression::~CallExpression() = default; +LoadExpression::~LoadExpression() = default; +StoreExpression::~StoreExpression() = default; +AggregateValueExpression::~AggregateValueExpression() = default; +PHIExpression::~PHIExpression() = default; +} +} + +// Congruence classes represent the set of expressions/instructions +// that are all the same *during some scope in the function*. +// That is, because of the way we perform equality propagation, and +// because of memory value numbering, it is not correct to assume +// you can willy-nilly replace any member with any other at any +// point in the function. +// +// For any Value in the Member set, it is valid to replace any dominated member +// with that Value. +// +// Every congruence class has a leader, and the leader is used to +// symbolize instructions in a canonical way (IE every operand of an +// instruction that is a member of the same congruence class will +// always be replaced with leader during symbolization). +// To simplify symbolization, we keep the leader as a constant if class can be +// proved to be a constant value. +// Otherwise, the leader is a randomly chosen member of the value set, it does +// not matter which one is chosen. +// Each congruence class also has a defining expression, +// though the expression may be null. If it exists, it can be used for forward +// propagation and reassociation of values. +// +struct CongruenceClass { + using MemberSet = SmallPtrSet<Value *, 4>; + unsigned ID; + // Representative leader. + Value *RepLeader = nullptr; + // Defining Expression. + const Expression *DefiningExpr = nullptr; + // Actual members of this class. + MemberSet Members; + + // True if this class has no members left. This is mainly used for assertion + // purposes, and for skipping empty classes. + bool Dead = false; + + // Number of stores in this congruence class. + // This is used so we can detect store equivalence changes properly. + int StoreCount = 0; + + // The most dominating leader after our current leader, because the member set + // is not sorted and is expensive to keep sorted all the time. 
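As a rough illustration of the partition described in the comment above (structurally identical expressions share a class, every class has a representative leader, and dominated members may be replaced by that leader), here is a toy value-numbering table in plain C++. The string expression keys and the ToyClass layout are illustrative only; the pass's real classes also carry a defining expression, store counts, and the other fields listed below.

#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct ToyClass {
  unsigned ID;
  std::string Leader;               // representative member (or a constant)
  std::vector<std::string> Members;
};

int main() {
  // Expression key -> congruence class.  Two instructions whose symbolic
  // expressions are identical land in the same class, and any dominated
  // member can be rewritten in terms of the class leader.
  std::map<std::string, ToyClass> Classes;
  unsigned NextID = 1;

  auto number = [&](const std::string &Name, const std::string &Expr) {
    auto It = Classes.find(Expr);
    if (It == Classes.end())
      It = Classes.insert({Expr, {NextID++, Name, {}}}).first; // first member leads
    It->second.Members.push_back(Name);
    return It->second.ID;
  };

  // %x = add i32 %a, %b    %y = add i32 %a, %b    %z = mul i32 %a, %b
  unsigned X = number("%x", "add(%a,%b)");
  unsigned Y = number("%y", "add(%a,%b)");
  unsigned Z = number("%z", "mul(%a,%b)");

  std::printf("%%x -> class %u, %%y -> class %u, %%z -> class %u\n", X, Y, Z);
  std::printf("leader of class %u is %s\n", X, Classes["add(%a,%b)"].Leader.c_str());
  return 0;
}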
+ std::pair<Value *, unsigned int> NextLeader = {nullptr, ~0U}; + + explicit CongruenceClass(unsigned ID) : ID(ID) {} + CongruenceClass(unsigned ID, Value *Leader, const Expression *E) + : ID(ID), RepLeader(Leader), DefiningExpr(E) {} +}; + +namespace llvm { +template <> struct DenseMapInfo<const Expression *> { + static const Expression *getEmptyKey() { + auto Val = static_cast<uintptr_t>(-1); + Val <<= PointerLikeTypeTraits<const Expression *>::NumLowBitsAvailable; + return reinterpret_cast<const Expression *>(Val); + } + static const Expression *getTombstoneKey() { + auto Val = static_cast<uintptr_t>(~1U); + Val <<= PointerLikeTypeTraits<const Expression *>::NumLowBitsAvailable; + return reinterpret_cast<const Expression *>(Val); + } + static unsigned getHashValue(const Expression *V) { + return static_cast<unsigned>(V->getHashValue()); + } + static bool isEqual(const Expression *LHS, const Expression *RHS) { + if (LHS == RHS) + return true; + if (LHS == getTombstoneKey() || RHS == getTombstoneKey() || + LHS == getEmptyKey() || RHS == getEmptyKey()) + return false; + return *LHS == *RHS; + } +}; +} // end namespace llvm + +class NewGVN : public FunctionPass { + DominatorTree *DT; + const DataLayout *DL; + const TargetLibraryInfo *TLI; + AssumptionCache *AC; + AliasAnalysis *AA; + MemorySSA *MSSA; + MemorySSAWalker *MSSAWalker; + BumpPtrAllocator ExpressionAllocator; + ArrayRecycler<Value *> ArgRecycler; + + // Congruence class info. + CongruenceClass *InitialClass; + std::vector<CongruenceClass *> CongruenceClasses; + unsigned NextCongruenceNum; + + // Value Mappings. + DenseMap<Value *, CongruenceClass *> ValueToClass; + DenseMap<Value *, const Expression *> ValueToExpression; + + // A table storing which memorydefs/phis represent a memory state provably + // equivalent to another memory state. + // We could use the congruence class machinery, but the MemoryAccess's are + // abstract memory states, so they can only ever be equivalent to each other, + // and not to constants, etc. + DenseMap<const MemoryAccess *, MemoryAccess *> MemoryAccessEquiv; + + // Expression to class mapping. + using ExpressionClassMap = DenseMap<const Expression *, CongruenceClass *>; + ExpressionClassMap ExpressionToClass; + + // Which values have changed as a result of leader changes. + SmallPtrSet<Value *, 8> LeaderChanges; + + // Reachability info. + using BlockEdge = BasicBlockEdge; + DenseSet<BlockEdge> ReachableEdges; + SmallPtrSet<const BasicBlock *, 8> ReachableBlocks; + + // This is a bitvector because, on larger functions, we may have + // thousands of touched instructions at once (entire blocks, + // instructions with hundreds of uses, etc). Even with optimization + // for when we mark whole blocks as touched, when this was a + // SmallPtrSet or DenseSet, for some functions, we spent >20% of all + // the time in GVN just managing this list. The bitvector, on the + // other hand, efficiently supports test/set/clear of both + // individual and ranges, as well as "find next element" This + // enables us to use it as a worklist with essentially 0 cost. + BitVector TouchedInstructions; + + DenseMap<const BasicBlock *, std::pair<unsigned, unsigned>> BlockInstRange; + DenseMap<const DomTreeNode *, std::pair<unsigned, unsigned>> + DominatedInstRange; + +#ifndef NDEBUG + // Debugging for how many times each block and instruction got processed. + DenseMap<const Value *, unsigned> ProcessedCount; +#endif + + // DFS info. 
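The TouchedInstructions worklist described above depends on every instruction having a dense DFS number, so that marking a newly reachable block is a single range-set and the driver just scans for the next set bit. A minimal sketch of that pattern, using std::vector<bool> in place of llvm::BitVector; the numbers and names are illustrative.

#include <cstdio>
#include <vector>

int main() {
  const unsigned NumInsts = 16;        // instructions, numbered in DFS order
  std::vector<bool> Touched(NumInsts, false);

  // Marking a newly reachable block touches a contiguous DFS range at once.
  auto touchRange = [&](unsigned Begin, unsigned End) {  // half-open [Begin, End)
    for (unsigned I = Begin; I != End; ++I)
      Touched[I] = true;
  };

  touchRange(4, 9);      // e.g. every instruction of one block
  Touched[12] = true;    // a single user that needs reprocessing

  // The driver finds the next set bit, clears it, and reprocesses that
  // instruction; reprocessing may set further bits (users, memory users).
  // The real pass repeats this scan until no bits remain set.
  for (unsigned I = 0; I != NumInsts; ++I) {
    if (!Touched[I])
      continue;
    Touched[I] = false;
    std::printf("reprocess instruction with DFS number %u\n", I);
  }
  return 0;
}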
+ DenseMap<const BasicBlock *, std::pair<int, int>> DFSDomMap; + DenseMap<const Value *, unsigned> InstrDFS; + SmallVector<Value *, 32> DFSToInstr; + + // Deletion info. + SmallPtrSet<Instruction *, 8> InstructionsToErase; + +public: + static char ID; // Pass identification, replacement for typeid. + NewGVN() : FunctionPass(ID) { + initializeNewGVNPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + bool runGVN(Function &F, DominatorTree *DT, AssumptionCache *AC, + TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA); + +private: + // This transformation requires dominator postdominator info. + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<MemorySSAWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } + + // Expression handling. + const Expression *createExpression(Instruction *, const BasicBlock *); + const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *, + const BasicBlock *); + PHIExpression *createPHIExpression(Instruction *); + const VariableExpression *createVariableExpression(Value *); + const ConstantExpression *createConstantExpression(Constant *); + const Expression *createVariableOrConstant(Value *V, const BasicBlock *B); + const UnknownExpression *createUnknownExpression(Instruction *); + const StoreExpression *createStoreExpression(StoreInst *, MemoryAccess *, + const BasicBlock *); + LoadExpression *createLoadExpression(Type *, Value *, LoadInst *, + MemoryAccess *, const BasicBlock *); + + const CallExpression *createCallExpression(CallInst *, MemoryAccess *, + const BasicBlock *); + const AggregateValueExpression * + createAggregateValueExpression(Instruction *, const BasicBlock *); + bool setBasicExpressionInfo(Instruction *, BasicExpression *, + const BasicBlock *); + + // Congruence class handling. + CongruenceClass *createCongruenceClass(Value *Leader, const Expression *E) { + auto *result = new CongruenceClass(NextCongruenceNum++, Leader, E); + CongruenceClasses.emplace_back(result); + return result; + } + + CongruenceClass *createSingletonCongruenceClass(Value *Member) { + CongruenceClass *CClass = createCongruenceClass(Member, nullptr); + CClass->Members.insert(Member); + ValueToClass[Member] = CClass; + return CClass; + } + void initializeCongruenceClasses(Function &F); + + // Value number an Instruction or MemoryPhi. + void valueNumberMemoryPhi(MemoryPhi *); + void valueNumberInstruction(Instruction *); + + // Symbolic evaluation. + const Expression *checkSimplificationResults(Expression *, Instruction *, + Value *); + const Expression *performSymbolicEvaluation(Value *, const BasicBlock *); + const Expression *performSymbolicLoadEvaluation(Instruction *, + const BasicBlock *); + const Expression *performSymbolicStoreEvaluation(Instruction *, + const BasicBlock *); + const Expression *performSymbolicCallEvaluation(Instruction *, + const BasicBlock *); + const Expression *performSymbolicPHIEvaluation(Instruction *, + const BasicBlock *); + bool setMemoryAccessEquivTo(MemoryAccess *From, MemoryAccess *To); + const Expression *performSymbolicAggrValueEvaluation(Instruction *, + const BasicBlock *); + + // Congruence finding. + // Templated to allow them to work both on BB's and BB-edges. 
+ template <class T> + Value *lookupOperandLeader(Value *, const User *, const T &) const; + void performCongruenceFinding(Instruction *, const Expression *); + void moveValueToNewCongruenceClass(Instruction *, CongruenceClass *, + CongruenceClass *); + // Reachability handling. + void updateReachableEdge(BasicBlock *, BasicBlock *); + void processOutgoingEdges(TerminatorInst *, BasicBlock *); + bool isOnlyReachableViaThisEdge(const BasicBlockEdge &) const; + Value *findConditionEquivalence(Value *, BasicBlock *) const; + MemoryAccess *lookupMemoryAccessEquiv(MemoryAccess *) const; + + // Elimination. + struct ValueDFS; + void convertDenseToDFSOrdered(CongruenceClass::MemberSet &, + SmallVectorImpl<ValueDFS> &); + + bool eliminateInstructions(Function &); + void replaceInstruction(Instruction *, Value *); + void markInstructionForDeletion(Instruction *); + void deleteInstructionsInBlock(BasicBlock *); + + // New instruction creation. + void handleNewInstruction(Instruction *){}; + + // Various instruction touch utilities + void markUsersTouched(Value *); + void markMemoryUsersTouched(MemoryAccess *); + void markLeaderChangeTouched(CongruenceClass *CC); + + // Utilities. + void cleanupTables(); + std::pair<unsigned, unsigned> assignDFSNumbers(BasicBlock *, unsigned); + void updateProcessedCount(Value *V); + void verifyMemoryCongruency() const; + bool singleReachablePHIPath(const MemoryAccess *, const MemoryAccess *) const; +}; + +char NewGVN::ID = 0; + +// createGVNPass - The public interface to this file. +FunctionPass *llvm::createNewGVNPass() { return new NewGVN(); } + +template <typename T> +static bool equalsLoadStoreHelper(const T &LHS, const Expression &RHS) { + if ((!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS)) || + !LHS.BasicExpression::equals(RHS)) { + return false; + } else if (const auto *L = dyn_cast<LoadExpression>(&RHS)) { + if (LHS.getDefiningAccess() != L->getDefiningAccess()) + return false; + } else if (const auto *S = dyn_cast<StoreExpression>(&RHS)) { + if (LHS.getDefiningAccess() != S->getDefiningAccess()) + return false; + } + return true; +} + +bool LoadExpression::equals(const Expression &Other) const { + return equalsLoadStoreHelper(*this, Other); +} + +bool StoreExpression::equals(const Expression &Other) const { + return equalsLoadStoreHelper(*this, Other); +} + +#ifndef NDEBUG +static std::string getBlockName(const BasicBlock *B) { + return DOTGraphTraits<const Function *>::getSimpleNodeLabel(B, nullptr); +} +#endif + +INITIALIZE_PASS_BEGIN(NewGVN, "newgvn", "Global Value Numbering", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_END(NewGVN, "newgvn", "Global Value Numbering", false, false) + +PHIExpression *NewGVN::createPHIExpression(Instruction *I) { + BasicBlock *PHIBlock = I->getParent(); + auto *PN = cast<PHINode>(I); + auto *E = + new (ExpressionAllocator) PHIExpression(PN->getNumOperands(), PHIBlock); + + E->allocateOperands(ArgRecycler, ExpressionAllocator); + E->setType(I->getType()); + E->setOpcode(I->getOpcode()); + + auto ReachablePhiArg = [&](const Use &U) { + return ReachableBlocks.count(PN->getIncomingBlock(U)); + }; + + // Filter out unreachable operands + auto Filtered = make_filter_range(PN->operands(), ReachablePhiArg); + + 
std::transform(Filtered.begin(), Filtered.end(), op_inserter(E), + [&](const Use &U) -> Value * { + // Don't try to transform self-defined phis. + if (U == PN) + return PN; + const BasicBlockEdge BBE(PN->getIncomingBlock(U), PHIBlock); + return lookupOperandLeader(U, I, BBE); + }); + return E; +} + +// Set basic expression info (Arguments, type, opcode) for Expression +// E from Instruction I in block B. +bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E, + const BasicBlock *B) { + bool AllConstant = true; + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) + E->setType(GEP->getSourceElementType()); + else + E->setType(I->getType()); + E->setOpcode(I->getOpcode()); + E->allocateOperands(ArgRecycler, ExpressionAllocator); + + // Transform the operand array into an operand leader array, and keep track of + // whether all members are constant. + std::transform(I->op_begin(), I->op_end(), op_inserter(E), [&](Value *O) { + auto Operand = lookupOperandLeader(O, I, B); + AllConstant &= isa<Constant>(Operand); + return Operand; + }); + + return AllConstant; +} + +const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T, + Value *Arg1, Value *Arg2, + const BasicBlock *B) { + auto *E = new (ExpressionAllocator) BasicExpression(2); + + E->setType(T); + E->setOpcode(Opcode); + E->allocateOperands(ArgRecycler, ExpressionAllocator); + if (Instruction::isCommutative(Opcode)) { + // Ensure that commutative instructions that only differ by a permutation + // of their operands get the same value number by sorting the operand value + // numbers. Since all commutative instructions have two operands it is more + // efficient to sort by hand rather than using, say, std::sort. + if (Arg1 > Arg2) + std::swap(Arg1, Arg2); + } + E->op_push_back(lookupOperandLeader(Arg1, nullptr, B)); + E->op_push_back(lookupOperandLeader(Arg2, nullptr, B)); + + Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), *DL, TLI, + DT, AC); + if (const Expression *SimplifiedE = checkSimplificationResults(E, nullptr, V)) + return SimplifiedE; + return E; +} + +// Take a Value returned by simplification of Expression E/Instruction +// I, and see if it resulted in a simpler expression. If so, return +// that expression. +// TODO: Once finished, this should not take an Instruction, we only +// use it for printing. 
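The operand sort in createBinaryExpression above (and again in createExpression below) is what makes the numbering insensitive to operand order for commutative opcodes: a+b and b+a produce the same key. A standalone sketch of that canonicalization over already-assigned operand value numbers; the hash scheme and opcode constants here are illustrative, not the pass's real hashing.

#include <cstdint>
#include <cstdio>
#include <tuple>
#include <unordered_map>
#include <utility>

using VN = uint32_t;  // value numbers already assigned to the operands

struct KeyHash {
  size_t operator()(const std::tuple<unsigned, VN, VN> &K) const {
    return std::get<0>(K) * 1000003u ^ std::get<1>(K) * 9176u ^ std::get<2>(K);
  }
};

int main() {
  enum { Add = 1, Sub = 2 };           // stand-ins for instruction opcodes
  std::unordered_map<std::tuple<unsigned, VN, VN>, VN, KeyHash> Table;
  VN Next = 100;

  auto numberBinOp = [&](unsigned Opcode, VN L, VN R, bool Commutative) {
    if (Commutative && L > R)
      std::swap(L, R);                 // canonical operand order
    auto Key = std::make_tuple(Opcode, L, R);
    auto It = Table.find(Key);
    if (It == Table.end())
      It = Table.insert({Key, Next++}).first;
    return It->second;
  };

  VN A = 1, B = 2;
  std::printf("add a,b -> %u\n", numberBinOp(Add, A, B, true));   // 100
  std::printf("add b,a -> %u\n", numberBinOp(Add, B, A, true));   // 100 again
  std::printf("sub a,b -> %u\n", numberBinOp(Sub, A, B, false));  // 101
  std::printf("sub b,a -> %u\n", numberBinOp(Sub, B, A, false));  // 102
  return 0;
}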
+const Expression *NewGVN::checkSimplificationResults(Expression *E, + Instruction *I, Value *V) { + if (!V) + return nullptr; + if (auto *C = dyn_cast<Constant>(V)) { + if (I) + DEBUG(dbgs() << "Simplified " << *I << " to " + << " constant " << *C << "\n"); + NumGVNOpsSimplified++; + assert(isa<BasicExpression>(E) && + "We should always have had a basic expression here"); + + cast<BasicExpression>(E)->deallocateOperands(ArgRecycler); + ExpressionAllocator.Deallocate(E); + return createConstantExpression(C); + } else if (isa<Argument>(V) || isa<GlobalVariable>(V)) { + if (I) + DEBUG(dbgs() << "Simplified " << *I << " to " + << " variable " << *V << "\n"); + cast<BasicExpression>(E)->deallocateOperands(ArgRecycler); + ExpressionAllocator.Deallocate(E); + return createVariableExpression(V); + } + + CongruenceClass *CC = ValueToClass.lookup(V); + if (CC && CC->DefiningExpr) { + if (I) + DEBUG(dbgs() << "Simplified " << *I << " to " + << " expression " << *V << "\n"); + NumGVNOpsSimplified++; + assert(isa<BasicExpression>(E) && + "We should always have had a basic expression here"); + cast<BasicExpression>(E)->deallocateOperands(ArgRecycler); + ExpressionAllocator.Deallocate(E); + return CC->DefiningExpr; + } + return nullptr; +} + +const Expression *NewGVN::createExpression(Instruction *I, + const BasicBlock *B) { + + auto *E = new (ExpressionAllocator) BasicExpression(I->getNumOperands()); + + bool AllConstant = setBasicExpressionInfo(I, E, B); + + if (I->isCommutative()) { + // Ensure that commutative instructions that only differ by a permutation + // of their operands get the same value number by sorting the operand value + // numbers. Since all commutative instructions have two operands it is more + // efficient to sort by hand rather than using, say, std::sort. + assert(I->getNumOperands() == 2 && "Unsupported commutative instruction!"); + if (E->getOperand(0) > E->getOperand(1)) + E->swapOperands(0, 1); + } + + // Perform simplificaiton + // TODO: Right now we only check to see if we get a constant result. + // We may get a less than constant, but still better, result for + // some operations. + // IE + // add 0, x -> x + // and x, x -> x + // We should handle this by simply rewriting the expression. + if (auto *CI = dyn_cast<CmpInst>(I)) { + // Sort the operand value numbers so x<y and y>x get the same value + // number. + CmpInst::Predicate Predicate = CI->getPredicate(); + if (E->getOperand(0) > E->getOperand(1)) { + E->swapOperands(0, 1); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } + E->setOpcode((CI->getOpcode() << 8) | Predicate); + // TODO: 25% of our time is spent in SimplifyCmpInst with pointer operands + // TODO: Since we noop bitcasts, we may need to check types before + // simplifying, so that we don't end up simplifying based on a wrong + // type assumption. 
We should clean this up so we can use constants of the + // wrong type + + assert(I->getOperand(0)->getType() == I->getOperand(1)->getType() && + "Wrong types on cmp instruction"); + if ((E->getOperand(0)->getType() == I->getOperand(0)->getType() && + E->getOperand(1)->getType() == I->getOperand(1)->getType())) { + Value *V = SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), + *DL, TLI, DT, AC); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } + } else if (isa<SelectInst>(I)) { + if (isa<Constant>(E->getOperand(0)) || + (E->getOperand(1)->getType() == I->getOperand(1)->getType() && + E->getOperand(2)->getType() == I->getOperand(2)->getType())) { + Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1), + E->getOperand(2), *DL, TLI, DT, AC); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } + } else if (I->isBinaryOp()) { + Value *V = SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), + *DL, TLI, DT, AC); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } else if (auto *BI = dyn_cast<BitCastInst>(I)) { + Value *V = SimplifyInstruction(BI, *DL, TLI, DT, AC); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } else if (isa<GetElementPtrInst>(I)) { + Value *V = SimplifyGEPInst(E->getType(), + ArrayRef<Value *>(E->op_begin(), E->op_end()), + *DL, TLI, DT, AC); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } else if (AllConstant) { + // We don't bother trying to simplify unless all of the operands + // were constant. + // TODO: There are a lot of Simplify*'s we could call here, if we + // wanted to. The original motivating case for this code was a + // zext i1 false to i8, which we don't have an interface to + // simplify (IE there is no SimplifyZExt). 
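When every operand leader turns out to be a constant, the expression can be folded outright even for opcodes that have no dedicated simplification entry point, and the resulting constant becomes the class leader. A toy version of that "fold only if all operands are constant" rule, restricted to integer add and mul with illustrative names:

#include <cstdio>
#include <optional>
#include <vector>

// A leader is either a known constant or an opaque SSA value.
struct Leader {
  bool IsConst;
  long Value;  // meaningful only when IsConst is true
};

static std::optional<long> foldIfAllConstant(char Opcode,
                                             const std::vector<Leader> &Ops) {
  long Result = (Opcode == '+') ? 0 : 1;
  for (const Leader &L : Ops) {
    if (!L.IsConst)
      return std::nullopt;           // give up: not all operands are constant
    Result = (Opcode == '+') ? Result + L.Value : Result * L.Value;
  }
  return Result;                     // all constant: fold to a single constant
}

int main() {
  std::vector<Leader> AllConst = {{true, 3}, {true, 4}};
  std::vector<Leader> Mixed    = {{true, 3}, {false, 0}};

  if (auto R = foldIfAllConstant('+', AllConst))
    std::printf("3 + 4 folds to %ld\n", *R);
  if (!foldIfAllConstant('*', Mixed))
    std::printf("3 * %%x does not fold; keep the expression\n");
  return 0;
}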
+ + SmallVector<Constant *, 8> C; + for (Value *Arg : E->operands()) + C.emplace_back(cast<Constant>(Arg)); + + if (Value *V = ConstantFoldInstOperands(I, C, *DL, TLI)) + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } + return E; +} + +const AggregateValueExpression * +NewGVN::createAggregateValueExpression(Instruction *I, const BasicBlock *B) { + if (auto *II = dyn_cast<InsertValueInst>(I)) { + auto *E = new (ExpressionAllocator) + AggregateValueExpression(I->getNumOperands(), II->getNumIndices()); + setBasicExpressionInfo(I, E, B); + E->allocateIntOperands(ExpressionAllocator); + std::copy(II->idx_begin(), II->idx_end(), int_op_inserter(E)); + return E; + } else if (auto *EI = dyn_cast<ExtractValueInst>(I)) { + auto *E = new (ExpressionAllocator) + AggregateValueExpression(I->getNumOperands(), EI->getNumIndices()); + setBasicExpressionInfo(EI, E, B); + E->allocateIntOperands(ExpressionAllocator); + std::copy(EI->idx_begin(), EI->idx_end(), int_op_inserter(E)); + return E; + } + llvm_unreachable("Unhandled type of aggregate value operation"); +} + +const VariableExpression *NewGVN::createVariableExpression(Value *V) { + auto *E = new (ExpressionAllocator) VariableExpression(V); + E->setOpcode(V->getValueID()); + return E; +} + +const Expression *NewGVN::createVariableOrConstant(Value *V, + const BasicBlock *B) { + auto Leader = lookupOperandLeader(V, nullptr, B); + if (auto *C = dyn_cast<Constant>(Leader)) + return createConstantExpression(C); + return createVariableExpression(Leader); +} + +const ConstantExpression *NewGVN::createConstantExpression(Constant *C) { + auto *E = new (ExpressionAllocator) ConstantExpression(C); + E->setOpcode(C->getValueID()); + return E; +} + +const UnknownExpression *NewGVN::createUnknownExpression(Instruction *I) { + auto *E = new (ExpressionAllocator) UnknownExpression(I); + E->setOpcode(I->getOpcode()); + return E; +} + +const CallExpression *NewGVN::createCallExpression(CallInst *CI, + MemoryAccess *HV, + const BasicBlock *B) { + // FIXME: Add operand bundles for calls. + auto *E = + new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, HV); + setBasicExpressionInfo(CI, E, B); + return E; +} + +// See if we have a congruence class and leader for this operand, and if so, +// return it. Otherwise, return the operand itself. +template <class T> +Value *NewGVN::lookupOperandLeader(Value *V, const User *U, const T &B) const { + CongruenceClass *CC = ValueToClass.lookup(V); + if (CC && (CC != InitialClass)) + return CC->RepLeader; + return V; +} + +MemoryAccess *NewGVN::lookupMemoryAccessEquiv(MemoryAccess *MA) const { + MemoryAccess *Result = MemoryAccessEquiv.lookup(MA); + return Result ? Result : MA; +} + +LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp, + LoadInst *LI, MemoryAccess *DA, + const BasicBlock *B) { + auto *E = new (ExpressionAllocator) LoadExpression(1, LI, DA); + E->allocateOperands(ArgRecycler, ExpressionAllocator); + E->setType(LoadType); + + // Give store and loads same opcode so they value number together. + E->setOpcode(0); + E->op_push_back(lookupOperandLeader(PointerOp, LI, B)); + if (LI) + E->setAlignment(LI->getAlignment()); + + // TODO: Value number heap versions. We may be able to discover + // things alias analysis can't on it's own (IE that a store and a + // load have the same value, and thus, it isn't clobbering the load). 
+ return E; +} + +const StoreExpression *NewGVN::createStoreExpression(StoreInst *SI, + MemoryAccess *DA, + const BasicBlock *B) { + auto *E = + new (ExpressionAllocator) StoreExpression(SI->getNumOperands(), SI, DA); + E->allocateOperands(ArgRecycler, ExpressionAllocator); + E->setType(SI->getValueOperand()->getType()); + + // Give store and loads same opcode so they value number together. + E->setOpcode(0); + E->op_push_back(lookupOperandLeader(SI->getPointerOperand(), SI, B)); + + // TODO: Value number heap versions. We may be able to discover + // things alias analysis can't on it's own (IE that a store and a + // load have the same value, and thus, it isn't clobbering the load). + return E; +} + +// Utility function to check whether the congruence class has a member other +// than the given instruction. +bool hasMemberOtherThanUs(const CongruenceClass *CC, Instruction *I) { + // Either it has more than one store, in which case it must contain something + // other than us (because it's indexed by value), or if it only has one store + // right now, that member should not be us. + return CC->StoreCount > 1 || CC->Members.count(I) == 0; +} + +const Expression *NewGVN::performSymbolicStoreEvaluation(Instruction *I, + const BasicBlock *B) { + // Unlike loads, we never try to eliminate stores, so we do not check if they + // are simple and avoid value numbering them. + auto *SI = cast<StoreInst>(I); + MemoryAccess *StoreAccess = MSSA->getMemoryAccess(SI); + // See if we are defined by a previous store expression, it already has a + // value, and it's the same value as our current store. FIXME: Right now, we + // only do this for simple stores, we should expand to cover memcpys, etc. + if (SI->isSimple()) { + // Get the expression, if any, for the RHS of the MemoryDef. + MemoryAccess *StoreRHS = lookupMemoryAccessEquiv( + cast<MemoryDef>(StoreAccess)->getDefiningAccess()); + const Expression *OldStore = createStoreExpression(SI, StoreRHS, B); + CongruenceClass *CC = ExpressionToClass.lookup(OldStore); + // Basically, check if the congruence class the store is in is defined by a + // store that isn't us, and has the same value. MemorySSA takes care of + // ensuring the store has the same memory state as us already. + if (CC && CC->DefiningExpr && isa<StoreExpression>(CC->DefiningExpr) && + CC->RepLeader == lookupOperandLeader(SI->getValueOperand(), SI, B) && + hasMemberOtherThanUs(CC, I)) + return createStoreExpression(SI, StoreRHS, B); + } + + return createStoreExpression(SI, StoreAccess, B); +} + +const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I, + const BasicBlock *B) { + auto *LI = cast<LoadInst>(I); + + // We can eliminate in favor of non-simple loads, but we won't be able to + // eliminate the loads themselves. + if (!LI->isSimple()) + return nullptr; + + Value *LoadAddressLeader = lookupOperandLeader(LI->getPointerOperand(), I, B); + // Load of undef is undef. + if (isa<UndefValue>(LoadAddressLeader)) + return createConstantExpression(UndefValue::get(LI->getType())); + + MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(I); + + if (!MSSA->isLiveOnEntryDef(DefiningAccess)) { + if (auto *MD = dyn_cast<MemoryDef>(DefiningAccess)) { + Instruction *DefiningInst = MD->getMemoryInst(); + // If the defining instruction is not reachable, replace with undef. 
+ if (!ReachableBlocks.count(DefiningInst->getParent())) + return createConstantExpression(UndefValue::get(LI->getType())); + } + } + + const Expression *E = + createLoadExpression(LI->getType(), LI->getPointerOperand(), LI, + lookupMemoryAccessEquiv(DefiningAccess), B); + return E; +} + +// Evaluate read only and pure calls, and create an expression result. +const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I, + const BasicBlock *B) { + auto *CI = cast<CallInst>(I); + if (AA->doesNotAccessMemory(CI)) + return createCallExpression(CI, nullptr, B); + if (AA->onlyReadsMemory(CI)) { + MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(CI); + return createCallExpression(CI, lookupMemoryAccessEquiv(DefiningAccess), B); + } + return nullptr; +} + +// Update the memory access equivalence table to say that From is equal to To, +// and return true if this is different from what already existed in the table. +bool NewGVN::setMemoryAccessEquivTo(MemoryAccess *From, MemoryAccess *To) { + DEBUG(dbgs() << "Setting " << *From << " equivalent to "); + if (!To) + DEBUG(dbgs() << "itself"); + else + DEBUG(dbgs() << *To); + DEBUG(dbgs() << "\n"); + auto LookupResult = MemoryAccessEquiv.find(From); + bool Changed = false; + // If it's already in the table, see if the value changed. + if (LookupResult != MemoryAccessEquiv.end()) { + if (To && LookupResult->second != To) { + // It wasn't equivalent before, and now it is. + LookupResult->second = To; + Changed = true; + } else if (!To) { + // It used to be equivalent to something, and now it's not. + MemoryAccessEquiv.erase(LookupResult); + Changed = true; + } + } else { + assert(!To && + "Memory equivalence should never change from nothing to something"); + } + + return Changed; +} +// Evaluate PHI nodes symbolically, and create an expression result. +const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I, + const BasicBlock *B) { + auto *E = cast<PHIExpression>(createPHIExpression(I)); + // We match the semantics of SimplifyPhiNode from InstructionSimplify here. + + // See if all arguaments are the same. + // We track if any were undef because they need special handling. + bool HasUndef = false; + auto Filtered = make_filter_range(E->operands(), [&](const Value *Arg) { + if (Arg == I) + return false; + if (isa<UndefValue>(Arg)) { + HasUndef = true; + return false; + } + return true; + }); + // If we are left with no operands, it's undef + if (Filtered.begin() == Filtered.end()) { + DEBUG(dbgs() << "Simplified PHI node " << *I << " to undef" + << "\n"); + E->deallocateOperands(ArgRecycler); + ExpressionAllocator.Deallocate(E); + return createConstantExpression(UndefValue::get(I->getType())); + } + Value *AllSameValue = *(Filtered.begin()); + ++Filtered.begin(); + // Can't use std::equal here, sadly, because filter.begin moves. + if (llvm::all_of(Filtered, [AllSameValue](const Value *V) { + return V == AllSameValue; + })) { + // In LLVM's non-standard representation of phi nodes, it's possible to have + // phi nodes with cycles (IE dependent on other phis that are .... dependent + // on the original phi node), especially in weird CFG's where some arguments + // are unreachable, or uninitialized along certain paths. This can cause + // infinite loops during evaluation. We work around this by not trying to + // really evaluate them independently, but instead using a variable + // expression to say if one is equivalent to the other. 
+ // We also special case undef, so that if we have an undef, we can't use the + // common value unless it dominates the phi block. + if (HasUndef) { + // Only have to check for instructions + if (auto *AllSameInst = dyn_cast<Instruction>(AllSameValue)) + if (!DT->dominates(AllSameInst, I)) + return E; + } + + NumGVNPhisAllSame++; + DEBUG(dbgs() << "Simplified PHI node " << *I << " to " << *AllSameValue + << "\n"); + E->deallocateOperands(ArgRecycler); + ExpressionAllocator.Deallocate(E); + if (auto *C = dyn_cast<Constant>(AllSameValue)) + return createConstantExpression(C); + return createVariableExpression(AllSameValue); + } + return E; +} + +const Expression * +NewGVN::performSymbolicAggrValueEvaluation(Instruction *I, + const BasicBlock *B) { + if (auto *EI = dyn_cast<ExtractValueInst>(I)) { + auto *II = dyn_cast<IntrinsicInst>(EI->getAggregateOperand()); + if (II && EI->getNumIndices() == 1 && *EI->idx_begin() == 0) { + unsigned Opcode = 0; + // EI might be an extract from one of our recognised intrinsics. If it + // is we'll synthesize a semantically equivalent expression instead on + // an extract value expression. + switch (II->getIntrinsicID()) { + case Intrinsic::sadd_with_overflow: + case Intrinsic::uadd_with_overflow: + Opcode = Instruction::Add; + break; + case Intrinsic::ssub_with_overflow: + case Intrinsic::usub_with_overflow: + Opcode = Instruction::Sub; + break; + case Intrinsic::smul_with_overflow: + case Intrinsic::umul_with_overflow: + Opcode = Instruction::Mul; + break; + default: + break; + } + + if (Opcode != 0) { + // Intrinsic recognized. Grab its args to finish building the + // expression. + assert(II->getNumArgOperands() == 2 && + "Expect two args for recognised intrinsics."); + return createBinaryExpression(Opcode, EI->getType(), + II->getArgOperand(0), + II->getArgOperand(1), B); + } + } + } + + return createAggregateValueExpression(I, B); +} + +// Substitute and symbolize the value before value numbering. +const Expression *NewGVN::performSymbolicEvaluation(Value *V, + const BasicBlock *B) { + const Expression *E = nullptr; + if (auto *C = dyn_cast<Constant>(V)) + E = createConstantExpression(C); + else if (isa<Argument>(V) || isa<GlobalVariable>(V)) { + E = createVariableExpression(V); + } else { + // TODO: memory intrinsics. + // TODO: Some day, we should do the forward propagation and reassociation + // parts of the algorithm. 
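The PHI rule implemented above boils down to: drop incoming values that are the node itself or undef, and if every remaining incoming value is the same thing, the PHI is congruent to that value (with an extra dominance requirement once an undef was filtered out). A small sketch of that filter-and-compare step, with plain strings standing in for operand leaders; the names are illustrative and the dominance check is only noted in a comment.

#include <cstdio>
#include <string>
#include <vector>

// Returns the single value the PHI reduces to, or "" if it stays a real PHI.
static std::string simplifyPhi(const std::string &PhiName,
                               const std::vector<std::string> &Incoming) {
  std::string Common;
  bool SawUndef = false;
  for (const std::string &V : Incoming) {
    if (V == PhiName)                 // self-reference: ignore
      continue;
    if (V == "undef") {               // undef: ignore, but remember we saw it
      SawUndef = true;
      continue;
    }
    if (Common.empty())
      Common = V;
    else if (V != Common)
      return "";                      // two different real values: keep the PHI
  }
  if (Common.empty())
    return "undef";                   // nothing but self/undef incoming
  // In the real pass, when an undef was filtered out the common value must
  // also dominate the PHI block before it can be used; omitted here.
  (void)SawUndef;
  return Common;
}

int main() {
  std::printf("phi [%%a, %%a, undef] -> %s\n",
              simplifyPhi("%p", {"%a", "%a", "undef"}).c_str());
  std::string R = simplifyPhi("%p", {"%a", "%b"});
  std::printf("phi [%%a, %%b] -> %s\n", R.empty() ? "kept as phi" : R.c_str());
  return 0;
}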
+ auto *I = cast<Instruction>(V); + switch (I->getOpcode()) { + case Instruction::ExtractValue: + case Instruction::InsertValue: + E = performSymbolicAggrValueEvaluation(I, B); + break; + case Instruction::PHI: + E = performSymbolicPHIEvaluation(I, B); + break; + case Instruction::Call: + E = performSymbolicCallEvaluation(I, B); + break; + case Instruction::Store: + E = performSymbolicStoreEvaluation(I, B); + break; + case Instruction::Load: + E = performSymbolicLoadEvaluation(I, B); + break; + case Instruction::BitCast: { + E = createExpression(I, B); + } break; + + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::ICmp: + case Instruction::FCmp: + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::Select: + case Instruction::ExtractElement: + case Instruction::InsertElement: + case Instruction::ShuffleVector: + case Instruction::GetElementPtr: + E = createExpression(I, B); + break; + default: + return nullptr; + } + } + return E; +} + +// There is an edge from 'Src' to 'Dst'. Return true if every path from +// the entry block to 'Dst' passes via this edge. In particular 'Dst' +// must not be reachable via another edge from 'Src'. +bool NewGVN::isOnlyReachableViaThisEdge(const BasicBlockEdge &E) const { + + // While in theory it is interesting to consider the case in which Dst has + // more than one predecessor, because Dst might be part of a loop which is + // only reachable from Src, in practice it is pointless since at the time + // GVN runs all such loops have preheaders, which means that Dst will have + // been changed to have only one predecessor, namely Src. + const BasicBlock *Pred = E.getEnd()->getSinglePredecessor(); + const BasicBlock *Src = E.getStart(); + assert((!Pred || Pred == Src) && "No edge between these basic blocks!"); + (void)Src; + return Pred != nullptr; +} + +void NewGVN::markUsersTouched(Value *V) { + // Now mark the users as touched. + for (auto *User : V->users()) { + assert(isa<Instruction>(User) && "Use of value not within an instruction?"); + TouchedInstructions.set(InstrDFS[User]); + } +} + +void NewGVN::markMemoryUsersTouched(MemoryAccess *MA) { + for (auto U : MA->users()) { + if (auto *MUD = dyn_cast<MemoryUseOrDef>(U)) + TouchedInstructions.set(InstrDFS[MUD->getMemoryInst()]); + else + TouchedInstructions.set(InstrDFS[U]); + } +} + +// Touch the instructions that need to be updated after a congruence class has a +// leader change, and mark changed values. 
+void NewGVN::markLeaderChangeTouched(CongruenceClass *CC) { + for (auto M : CC->Members) { + if (auto *I = dyn_cast<Instruction>(M)) + TouchedInstructions.set(InstrDFS[I]); + LeaderChanges.insert(M); + } +} + +// Move a value, currently in OldClass, to be part of NewClass +// Update OldClass for the move (including changing leaders, etc) +void NewGVN::moveValueToNewCongruenceClass(Instruction *I, + CongruenceClass *OldClass, + CongruenceClass *NewClass) { + DEBUG(dbgs() << "New congruence class for " << I << " is " << NewClass->ID + << "\n"); + + if (I == OldClass->NextLeader.first) + OldClass->NextLeader = {nullptr, ~0U}; + + // It's possible, though unlikely, for us to discover equivalences such + // that the current leader does not dominate the old one. + // This statistic tracks how often this happens. + // We assert on phi nodes when this happens, currently, for debugging, because + // we want to make sure we name phi node cycles properly. + if (isa<Instruction>(NewClass->RepLeader) && NewClass->RepLeader && + I != NewClass->RepLeader && + DT->properlyDominates( + I->getParent(), + cast<Instruction>(NewClass->RepLeader)->getParent())) { + ++NumGVNNotMostDominatingLeader; + assert(!isa<PHINode>(I) && + "New class for instruction should not be dominated by instruction"); + } + + if (NewClass->RepLeader != I) { + auto DFSNum = InstrDFS.lookup(I); + if (DFSNum < NewClass->NextLeader.second) + NewClass->NextLeader = {I, DFSNum}; + } + + OldClass->Members.erase(I); + NewClass->Members.insert(I); + if (isa<StoreInst>(I)) { + --OldClass->StoreCount; + assert(OldClass->StoreCount >= 0); + ++NewClass->StoreCount; + assert(NewClass->StoreCount > 0); + } + + ValueToClass[I] = NewClass; + // See if we destroyed the class or need to swap leaders. + if (OldClass->Members.empty() && OldClass != InitialClass) { + if (OldClass->DefiningExpr) { + OldClass->Dead = true; + DEBUG(dbgs() << "Erasing expression " << OldClass->DefiningExpr + << " from table\n"); + ExpressionToClass.erase(OldClass->DefiningExpr); + } + } else if (OldClass->RepLeader == I) { + // When the leader changes, the value numbering of + // everything may change due to symbolization changes, so we need to + // reprocess. + DEBUG(dbgs() << "Leader change!\n"); + ++NumGVNLeaderChanges; + // We don't need to sort members if there is only 1, and we don't care about + // sorting the initial class because everything either gets out of it or is + // unreachable. + if (OldClass->Members.size() == 1 || OldClass == InitialClass) { + OldClass->RepLeader = *(OldClass->Members.begin()); + } else if (OldClass->NextLeader.first) { + ++NumGVNAvoidedSortedLeaderChanges; + OldClass->RepLeader = OldClass->NextLeader.first; + OldClass->NextLeader = {nullptr, ~0U}; + } else { + ++NumGVNSortedLeaderChanges; + // TODO: If this ends up to slow, we can maintain a dual structure for + // member testing/insertion, or keep things mostly sorted, and sort only + // here, or .... + std::pair<Value *, unsigned> MinDFS = {nullptr, ~0U}; + for (const auto X : OldClass->Members) { + auto DFSNum = InstrDFS.lookup(X); + if (DFSNum < MinDFS.second) + MinDFS = {X, DFSNum}; + } + OldClass->RepLeader = MinDFS.first; + } + markLeaderChangeTouched(OldClass); + } +} + +// Perform congruence finding on a given value numbering expression. +void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) { + ValueToExpression[I] = E; + // This is guaranteed to return something, since it will at least find + // INITIAL. 
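When the departing member was the class leader, moveValueToNewCongruenceClass above either promotes the cached NextLeader or falls back to scanning the remaining members for the smallest DFS number, i.e. the earliest member in the instruction numbering, which the pass treats as the most dominating candidate. That fallback scan in isolation, with illustrative names and plain containers:

#include <climits>
#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <utility>

int main() {
  // Remaining members of the class after the old leader left, and the DFS
  // number assigned to each instruction when the function was numbered.
  std::set<std::string> Members = {"%t7", "%t3", "%t9"};
  std::map<std::string, unsigned> InstrDFS = {{"%t3", 3}, {"%t7", 7}, {"%t9", 9}};

  // Pick the member with the minimum DFS number as the new representative.
  std::pair<std::string, unsigned> MinDFS = {"", UINT_MAX};
  for (const std::string &M : Members) {
    unsigned N = InstrDFS.at(M);
    if (N < MinDFS.second)
      MinDFS = {M, N};
  }
  std::printf("new leader: %s (DFS %u)\n", MinDFS.first.c_str(), MinDFS.second);
  return 0;
}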
+ + CongruenceClass *IClass = ValueToClass[I]; + assert(IClass && "Should have found a IClass"); + // Dead classes should have been eliminated from the mapping. + assert(!IClass->Dead && "Found a dead class"); + + CongruenceClass *EClass; + if (const auto *VE = dyn_cast<VariableExpression>(E)) { + EClass = ValueToClass[VE->getVariableValue()]; + } else { + auto lookupResult = ExpressionToClass.insert({E, nullptr}); + + // If it's not in the value table, create a new congruence class. + if (lookupResult.second) { + CongruenceClass *NewClass = createCongruenceClass(nullptr, E); + auto place = lookupResult.first; + place->second = NewClass; + + // Constants and variables should always be made the leader. + if (const auto *CE = dyn_cast<ConstantExpression>(E)) { + NewClass->RepLeader = CE->getConstantValue(); + } else if (const auto *SE = dyn_cast<StoreExpression>(E)) { + StoreInst *SI = SE->getStoreInst(); + NewClass->RepLeader = + lookupOperandLeader(SI->getValueOperand(), SI, SI->getParent()); + } else { + NewClass->RepLeader = I; + } + assert(!isa<VariableExpression>(E) && + "VariableExpression should have been handled already"); + + EClass = NewClass; + DEBUG(dbgs() << "Created new congruence class for " << *I + << " using expression " << *E << " at " << NewClass->ID + << " and leader " << *(NewClass->RepLeader) << "\n"); + DEBUG(dbgs() << "Hash value was " << E->getHashValue() << "\n"); + } else { + EClass = lookupResult.first->second; + if (isa<ConstantExpression>(E)) + assert(isa<Constant>(EClass->RepLeader) && + "Any class with a constant expression should have a " + "constant leader"); + + assert(EClass && "Somehow don't have an eclass"); + + assert(!EClass->Dead && "We accidentally looked up a dead class"); + } + } + bool ClassChanged = IClass != EClass; + bool LeaderChanged = LeaderChanges.erase(I); + if (ClassChanged || LeaderChanged) { + DEBUG(dbgs() << "Found class " << EClass->ID << " for expression " << E + << "\n"); + + if (ClassChanged) + moveValueToNewCongruenceClass(I, IClass, EClass); + markUsersTouched(I); + if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) { + // If this is a MemoryDef, we need to update the equivalence table. If + // we determined the expression is congruent to a different memory + // state, use that different memory state. If we determined it didn't, + // we update that as well. Right now, we only support store + // expressions. + if (!isa<MemoryUse>(MA) && isa<StoreExpression>(E) && + EClass->Members.size() != 1) { + auto *DefAccess = cast<StoreExpression>(E)->getDefiningAccess(); + setMemoryAccessEquivTo(MA, DefAccess != MA ? DefAccess : nullptr); + } else { + setMemoryAccessEquivTo(MA, nullptr); + } + markMemoryUsersTouched(MA); + } + } else if (auto *SI = dyn_cast<StoreInst>(I)) { + // There is, sadly, one complicating thing for stores. Stores do not + // produce values, only consume them. However, in order to make loads and + // stores value number the same, we ignore the value operand of the store. + // But the value operand will still be the leader of our class, and thus, it + // may change. Because the store is a use, the store will get reprocessed, + // but nothing will change about it, and so nothing above will catch it + // (since the class will not change). In order to make sure everything ends + // up okay, we need to recheck the leader of the class. Since stores of + // different values value number differently due to different memorydefs, we + // are guaranteed the leader is always the same between stores in the same + // class. 
+ DEBUG(dbgs() << "Checking store leader\n"); + auto ProperLeader = + lookupOperandLeader(SI->getValueOperand(), SI, SI->getParent()); + if (EClass->RepLeader != ProperLeader) { + DEBUG(dbgs() << "Store leader changed, fixing\n"); + EClass->RepLeader = ProperLeader; + markLeaderChangeTouched(EClass); + markMemoryUsersTouched(MSSA->getMemoryAccess(SI)); + } + } +} + +// Process the fact that Edge (from, to) is reachable, including marking +// any newly reachable blocks and instructions for processing. +void NewGVN::updateReachableEdge(BasicBlock *From, BasicBlock *To) { + // Check if the Edge was reachable before. + if (ReachableEdges.insert({From, To}).second) { + // If this block wasn't reachable before, all instructions are touched. + if (ReachableBlocks.insert(To).second) { + DEBUG(dbgs() << "Block " << getBlockName(To) << " marked reachable\n"); + const auto &InstRange = BlockInstRange.lookup(To); + TouchedInstructions.set(InstRange.first, InstRange.second); + } else { + DEBUG(dbgs() << "Block " << getBlockName(To) + << " was reachable, but new edge {" << getBlockName(From) + << "," << getBlockName(To) << "} to it found\n"); + + // We've made an edge reachable to an existing block, which may + // impact predicates. Otherwise, only mark the phi nodes as touched, as + // they are the only thing that depend on new edges. Anything using their + // values will get propagated to if necessary. + if (MemoryAccess *MemPhi = MSSA->getMemoryAccess(To)) + TouchedInstructions.set(InstrDFS[MemPhi]); + + auto BI = To->begin(); + while (isa<PHINode>(BI)) { + TouchedInstructions.set(InstrDFS[&*BI]); + ++BI; + } + } + } +} + +// Given a predicate condition (from a switch, cmp, or whatever) and a block, +// see if we know some constant value for it already. +Value *NewGVN::findConditionEquivalence(Value *Cond, BasicBlock *B) const { + auto Result = lookupOperandLeader(Cond, nullptr, B); + if (isa<Constant>(Result)) + return Result; + return nullptr; +} + +// Process the outgoing edges of a block for reachability. +void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) { + // Evaluate reachability of terminator instruction. + BranchInst *BR; + if ((BR = dyn_cast<BranchInst>(TI)) && BR->isConditional()) { + Value *Cond = BR->getCondition(); + Value *CondEvaluated = findConditionEquivalence(Cond, B); + if (!CondEvaluated) { + if (auto *I = dyn_cast<Instruction>(Cond)) { + const Expression *E = createExpression(I, B); + if (const auto *CE = dyn_cast<ConstantExpression>(E)) { + CondEvaluated = CE->getConstantValue(); + } + } else if (isa<ConstantInt>(Cond)) { + CondEvaluated = Cond; + } + } + ConstantInt *CI; + BasicBlock *TrueSucc = BR->getSuccessor(0); + BasicBlock *FalseSucc = BR->getSuccessor(1); + if (CondEvaluated && (CI = dyn_cast<ConstantInt>(CondEvaluated))) { + if (CI->isOne()) { + DEBUG(dbgs() << "Condition for Terminator " << *TI + << " evaluated to true\n"); + updateReachableEdge(B, TrueSucc); + } else if (CI->isZero()) { + DEBUG(dbgs() << "Condition for Terminator " << *TI + << " evaluated to false\n"); + updateReachableEdge(B, FalseSucc); + } + } else { + updateReachableEdge(B, TrueSucc); + updateReachableEdge(B, FalseSucc); + } + } else if (auto *SI = dyn_cast<SwitchInst>(TI)) { + // For switches, propagate the case values into the case + // destinations. + + // Remember how many outgoing edges there are to every successor. 
+ SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges; + + Value *SwitchCond = SI->getCondition(); + Value *CondEvaluated = findConditionEquivalence(SwitchCond, B); + // See if we were able to turn this switch statement into a constant. + if (CondEvaluated && isa<ConstantInt>(CondEvaluated)) { + auto *CondVal = cast<ConstantInt>(CondEvaluated); + // We should be able to get case value for this. + auto CaseVal = SI->findCaseValue(CondVal); + if (CaseVal.getCaseSuccessor() == SI->getDefaultDest()) { + // We proved the value is outside of the range of the case. + // We can't do anything other than mark the default dest as reachable, + // and go home. + updateReachableEdge(B, SI->getDefaultDest()); + return; + } + // Now get where it goes and mark it reachable. + BasicBlock *TargetBlock = CaseVal.getCaseSuccessor(); + updateReachableEdge(B, TargetBlock); + } else { + for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { + BasicBlock *TargetBlock = SI->getSuccessor(i); + ++SwitchEdges[TargetBlock]; + updateReachableEdge(B, TargetBlock); + } + } + } else { + // Otherwise this is either unconditional, or a type we have no + // idea about. Just mark successors as reachable. + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + BasicBlock *TargetBlock = TI->getSuccessor(i); + updateReachableEdge(B, TargetBlock); + } + + // This also may be a memory defining terminator, in which case, set it + // equivalent to nothing. + if (MemoryAccess *MA = MSSA->getMemoryAccess(TI)) + setMemoryAccessEquivTo(MA, nullptr); + } +} + +// The algorithm initially places the values of the routine in the INITIAL +// congruence +// class. The leader of INITIAL is the undetermined value `TOP`. +// When the algorithm has finished, values still in INITIAL are unreachable. +void NewGVN::initializeCongruenceClasses(Function &F) { + // FIXME now i can't remember why this is 2 + NextCongruenceNum = 2; + // Initialize all other instructions to be in INITIAL class. + CongruenceClass::MemberSet InitialValues; + InitialClass = createCongruenceClass(nullptr, nullptr); + for (auto &B : F) { + if (auto *MP = MSSA->getMemoryAccess(&B)) + MemoryAccessEquiv.insert({MP, MSSA->getLiveOnEntryDef()}); + + for (auto &I : B) { + InitialValues.insert(&I); + ValueToClass[&I] = InitialClass; + // All memory accesses are equivalent to live on entry to start. They must + // be initialized to something so that initial changes are noticed. For + // the maximal answer, we initialize them all to be the same as + // liveOnEntry. Note that to save time, we only initialize the + // MemoryDef's for stores and all MemoryPhis to be equal. Right now, no + // other expression can generate a memory equivalence. If we start + // handling memcpy/etc, we can expand this. + if (isa<StoreInst>(&I)) { + MemoryAccessEquiv.insert( + {MSSA->getMemoryAccess(&I), MSSA->getLiveOnEntryDef()}); + ++InitialClass->StoreCount; + assert(InitialClass->StoreCount > 0); + } + } + } + InitialClass->Members.swap(InitialValues); + + // Initialize arguments to be in their own unique congruence classes + for (auto &FA : F.args()) + createSingletonCongruenceClass(&FA); +} + +void NewGVN::cleanupTables() { + for (unsigned i = 0, e = CongruenceClasses.size(); i != e; ++i) { + DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->ID << " has " + << CongruenceClasses[i]->Members.size() << " members\n"); + // Make sure we delete the congruence class (probably worth switching to + // a unique_ptr at some point. 
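// Illustrative sketch, not part of this patch: the switch narrowing performed
// above, with integer block ids in place of BasicBlocks. If the condition is
// known to be a constant, only the matching case target (or the default when
// no case matches) has to be marked reachable; otherwise every successor stays
// live.
#include <map>
#include <set>

struct SwitchSketch {
  std::map<long, int> CaseToSucc; // case value -> successor block id
  int DefaultSucc = 0;

  std::set<int> reachableSuccessors(const long *KnownCond) const {
    std::set<int> Result;
    if (KnownCond) {
      auto It = CaseToSucc.find(*KnownCond);
      Result.insert(It == CaseToSucc.end() ? DefaultSucc : It->second);
    } else {
      Result.insert(DefaultSucc); // unknown condition: all edges stay live
      for (const auto &KV : CaseToSucc)
        Result.insert(KV.second);
    }
    return Result;
  }
};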
+ delete CongruenceClasses[i]; + CongruenceClasses[i] = nullptr; + } + + ValueToClass.clear(); + ArgRecycler.clear(ExpressionAllocator); + ExpressionAllocator.Reset(); + CongruenceClasses.clear(); + ExpressionToClass.clear(); + ValueToExpression.clear(); + ReachableBlocks.clear(); + ReachableEdges.clear(); +#ifndef NDEBUG + ProcessedCount.clear(); +#endif + DFSDomMap.clear(); + InstrDFS.clear(); + InstructionsToErase.clear(); + + DFSToInstr.clear(); + BlockInstRange.clear(); + TouchedInstructions.clear(); + DominatedInstRange.clear(); + MemoryAccessEquiv.clear(); +} + +std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B, + unsigned Start) { + unsigned End = Start; + if (MemoryAccess *MemPhi = MSSA->getMemoryAccess(B)) { + InstrDFS[MemPhi] = End++; + DFSToInstr.emplace_back(MemPhi); + } + + for (auto &I : *B) { + InstrDFS[&I] = End++; + DFSToInstr.emplace_back(&I); + } + + // All of the range functions taken half-open ranges (open on the end side). + // So we do not subtract one from count, because at this point it is one + // greater than the last instruction. + return std::make_pair(Start, End); +} + +void NewGVN::updateProcessedCount(Value *V) { +#ifndef NDEBUG + if (ProcessedCount.count(V) == 0) { + ProcessedCount.insert({V, 1}); + } else { + ProcessedCount[V] += 1; + assert(ProcessedCount[V] < 100 && + "Seem to have processed the same Value a lot"); + } +#endif +} +// Evaluate MemoryPhi nodes symbolically, just like PHI nodes +void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) { + // If all the arguments are the same, the MemoryPhi has the same value as the + // argument. + // Filter out unreachable blocks from our operands. + auto Filtered = make_filter_range(MP->operands(), [&](const Use &U) { + return ReachableBlocks.count(MP->getIncomingBlock(U)); + }); + + assert(Filtered.begin() != Filtered.end() && + "We should not be processing a MemoryPhi in a completely " + "unreachable block"); + + // Transform the remaining operands into operand leaders. + // FIXME: mapped_iterator should have a range version. + auto LookupFunc = [&](const Use &U) { + return lookupMemoryAccessEquiv(cast<MemoryAccess>(U)); + }; + auto MappedBegin = map_iterator(Filtered.begin(), LookupFunc); + auto MappedEnd = map_iterator(Filtered.end(), LookupFunc); + + // and now check if all the elements are equal. + // Sadly, we can't use std::equals since these are random access iterators. + MemoryAccess *AllSameValue = *MappedBegin; + ++MappedBegin; + bool AllEqual = std::all_of( + MappedBegin, MappedEnd, + [&AllSameValue](const MemoryAccess *V) { return V == AllSameValue; }); + + if (AllEqual) + DEBUG(dbgs() << "Memory Phi value numbered to " << *AllSameValue << "\n"); + else + DEBUG(dbgs() << "Memory Phi value numbered to itself\n"); + + if (setMemoryAccessEquivTo(MP, AllEqual ? AllSameValue : nullptr)) + markMemoryUsersTouched(MP); +} + +// Value number a single instruction, symbolically evaluating, performing +// congruence finding, and updating mappings. 
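// Illustrative sketch, not part of this patch: the "all reachable incoming
// values agree" test applied to MemoryPhis above, written against plain
// integers. Operands from unreachable predecessors are skipped, the remaining
// ones are mapped to their current leaders, and the phi takes that value only
// when a single leader remains.
#include <functional>
#include <optional>
#include <vector>

struct PhiOperandSketch {
  int IncomingBlock = 0;
  int Value = 0;
};

std::optional<int>
singlePhiValue(const std::vector<PhiOperandSketch> &Ops,
               const std::function<bool(int)> &BlockReachable,
               const std::function<int(int)> &LookupLeader) {
  std::optional<int> Same;
  for (const auto &Op : Ops) {
    if (!BlockReachable(Op.IncomingBlock))
      continue; // ignore incoming values along dead edges
    int Leader = LookupLeader(Op.Value);
    if (!Same)
      Same = Leader;
    else if (*Same != Leader)
      return std::nullopt; // operands disagree: the phi keeps its own value
  }
  return Same; // also nullopt if every operand turned out to be unreachable
}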
+void NewGVN::valueNumberInstruction(Instruction *I) { + DEBUG(dbgs() << "Processing instruction " << *I << "\n"); + if (isInstructionTriviallyDead(I, TLI)) { + DEBUG(dbgs() << "Skipping unused instruction\n"); + markInstructionForDeletion(I); + return; + } + if (!I->isTerminator()) { + const auto *Symbolized = performSymbolicEvaluation(I, I->getParent()); + // If we couldn't come up with a symbolic expression, use the unknown + // expression + if (Symbolized == nullptr) + Symbolized = createUnknownExpression(I); + performCongruenceFinding(I, Symbolized); + } else { + // Handle terminators that return values. All of them produce values we + // don't currently understand. + if (!I->getType()->isVoidTy()) { + auto *Symbolized = createUnknownExpression(I); + performCongruenceFinding(I, Symbolized); + } + processOutgoingEdges(dyn_cast<TerminatorInst>(I), I->getParent()); + } +} + +// Check if there is a path, using single or equal argument phi nodes, from +// First to Second. +bool NewGVN::singleReachablePHIPath(const MemoryAccess *First, + const MemoryAccess *Second) const { + if (First == Second) + return true; + + if (auto *FirstDef = dyn_cast<MemoryUseOrDef>(First)) { + auto *DefAccess = FirstDef->getDefiningAccess(); + return singleReachablePHIPath(DefAccess, Second); + } else { + auto *MP = cast<MemoryPhi>(First); + auto ReachableOperandPred = [&](const Use &U) { + return ReachableBlocks.count(MP->getIncomingBlock(U)); + }; + auto FilteredPhiArgs = + make_filter_range(MP->operands(), ReachableOperandPred); + SmallVector<const Value *, 32> OperandList; + std::copy(FilteredPhiArgs.begin(), FilteredPhiArgs.end(), + std::back_inserter(OperandList)); + bool Okay = OperandList.size() == 1; + if (!Okay) + Okay = std::equal(OperandList.begin(), OperandList.end(), + OperandList.begin()); + if (Okay) + return singleReachablePHIPath(cast<MemoryAccess>(OperandList[0]), Second); + return false; + } +} + +// Verify the that the memory equivalence table makes sense relative to the +// congruence classes. Note that this checking is not perfect, and is currently +// subject to very rare false negatives. It is only useful for testing/debugging. +void NewGVN::verifyMemoryCongruency() const { + // Anything equivalent in the memory access table should be in the same + // congruence class. + + // Filter out the unreachable and trivially dead entries, because they may + // never have been updated if the instructions were not processed. + auto ReachableAccessPred = + [&](const std::pair<const MemoryAccess *, MemoryAccess *> Pair) { + bool Result = ReachableBlocks.count(Pair.first->getBlock()); + if (!Result) + return false; + if (auto *MemDef = dyn_cast<MemoryDef>(Pair.first)) + return !isInstructionTriviallyDead(MemDef->getMemoryInst()); + return true; + }; + + auto Filtered = make_filter_range(MemoryAccessEquiv, ReachableAccessPred); + for (auto KV : Filtered) { + assert(KV.first != KV.second && + "We added a useless equivalence to the memory equivalence table"); + // Unreachable instructions may not have changed because we never process + // them. 
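// Illustrative sketch, not part of this patch: a standalone model of the
// single-reachable-path walk above. Nodes stand in for MemorySSA accesses; a
// def follows its single defining node, and a phi only passes the walk through
// when all of its (already filtered) operands agree.
#include <vector>

struct MemNodeSketch {
  enum Kind { LiveOnEntry, Def, Phi } K = LiveOnEntry;
  int DefiningNode = -1;        // used when K == Def
  std::vector<int> PhiOperands; // used when K == Phi
};

bool singleReachablePath(const std::vector<MemNodeSketch> &Nodes, int First,
                         int Second) {
  if (First == Second)
    return true;
  const MemNodeSketch &N = Nodes[First];
  if (N.K == MemNodeSketch::LiveOnEntry)
    return false; // the chain ended without reaching Second
  if (N.K == MemNodeSketch::Def)
    return singleReachablePath(Nodes, N.DefiningNode, Second);
  if (N.PhiOperands.empty())
    return false;
  for (int Op : N.PhiOperands)            // a phi is a pass-through only when
    if (Op != N.PhiOperands.front())      // all of its operands are the same
      return false;
  return singleReachablePath(Nodes, N.PhiOperands.front(), Second);
}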
+    if (!ReachableBlocks.count(KV.first->getBlock()))
+      continue;
+    if (auto *FirstMUD = dyn_cast<MemoryUseOrDef>(KV.first)) {
+      auto *SecondMUD = dyn_cast<MemoryUseOrDef>(KV.second);
+      if (FirstMUD && SecondMUD)
+        assert((singleReachablePHIPath(FirstMUD, SecondMUD) ||
+                ValueToClass.lookup(FirstMUD->getMemoryInst()) ==
+                    ValueToClass.lookup(SecondMUD->getMemoryInst())) &&
+               "The instructions for these memory operations should have "
+               "been in the same congruence class or reachable through "
+               "a single argument phi");
+    } else if (auto *FirstMP = dyn_cast<MemoryPhi>(KV.first)) {
+
+      // We can only sanely verify that MemoryDefs in the operand list all have
+      // the same class.
+      auto ReachableOperandPred = [&](const Use &U) {
+        return ReachableBlocks.count(FirstMP->getIncomingBlock(U)) &&
+               isa<MemoryDef>(U);
+
+      };
+      // All arguments should be in the same class, ignoring unreachable
+      // arguments.
+      auto FilteredPhiArgs =
+          make_filter_range(FirstMP->operands(), ReachableOperandPred);
+      SmallVector<const CongruenceClass *, 16> PhiOpClasses;
+      std::transform(FilteredPhiArgs.begin(), FilteredPhiArgs.end(),
+                     std::back_inserter(PhiOpClasses), [&](const Use &U) {
+                       const MemoryDef *MD = cast<MemoryDef>(U);
+                       return ValueToClass.lookup(MD->getMemoryInst());
+                     });
+      assert(std::equal(PhiOpClasses.begin(), PhiOpClasses.end(),
+                        PhiOpClasses.begin()) &&
+             "All MemoryPhi arguments should be in the same class");
+    }
+  }
+}
+
+// This is the main transformation entry point.
+bool NewGVN::runGVN(Function &F, DominatorTree *_DT, AssumptionCache *_AC,
+                    TargetLibraryInfo *_TLI, AliasAnalysis *_AA,
+                    MemorySSA *_MSSA) {
+  bool Changed = false;
+  DT = _DT;
+  AC = _AC;
+  TLI = _TLI;
+  AA = _AA;
+  MSSA = _MSSA;
+  DL = &F.getParent()->getDataLayout();
+  MSSAWalker = MSSA->getWalker();
+
+  // Count number of instructions for sizing of hash tables, and come
+  // up with a global dfs numbering for instructions.
+  unsigned ICount = 1;
+  // Add an empty instruction to account for the fact that we start at 1.
+  DFSToInstr.emplace_back(nullptr);
+  // Note: We want RPO traversal of the blocks, which is not quite the same as
+  // dominator tree order, particularly with regard to whether backedges get
+  // visited first or second, given a block with multiple successors.
+  // If we visit in the wrong order, we will end up performing N times as many
+  // iterations.
+  // The dominator tree does guarantee that, for a given dom tree node, its
+  // parent must occur before it in the RPO ordering. Thus, we only need to
+  // sort the siblings.
+  DenseMap<const DomTreeNode *, unsigned> RPOOrdering;
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  unsigned Counter = 0;
+  for (auto &B : RPOT) {
+    auto *Node = DT->getNode(B);
+    assert(Node && "RPO and Dominator tree should have same reachability");
+    RPOOrdering[Node] = ++Counter;
+  }
+  // Sort dominator tree children arrays into RPO.
+  for (auto &B : RPOT) {
+    auto *Node = DT->getNode(B);
+    if (Node->getChildren().size() > 1)
+      std::sort(Node->begin(), Node->end(),
+                [&RPOOrdering](const DomTreeNode *A, const DomTreeNode *B) {
+                  return RPOOrdering[A] < RPOOrdering[B];
+                });
+  }
+
+  // Now a standard depth first ordering of the domtree is equivalent to RPO.
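// Illustrative sketch, not part of this patch: the ordering trick described in
// the note above, on a toy CFG with integer block ids. Reverse post-order
// indices are computed once, and each dominator tree node's children are then
// sorted by that index, so a plain depth-first walk of the dominator tree
// visits blocks in RPO.
#include <algorithm>
#include <cstddef>
#include <map>
#include <vector>

struct CfgSketch {
  std::vector<std::vector<int>> Succs; // CFG successors per block id
};

static void postOrder(const CfgSketch &G, int Block, std::vector<bool> &Seen,
                      std::vector<int> &Order) {
  Seen[Block] = true;
  for (int S : G.Succs[Block])
    if (!Seen[S])
      postOrder(G, S, Seen, Order);
  Order.push_back(Block);
}

void sortDomChildrenIntoRPO(const CfgSketch &G, int Entry,
                            std::map<int, std::vector<int>> &DomChildren) {
  std::vector<bool> Seen(G.Succs.size(), false);
  std::vector<int> Order;
  postOrder(G, Entry, Seen, Order);

  std::vector<int> RPOIndex(G.Succs.size(), 0); // RPO = reversed post-order
  for (std::size_t I = 0; I < Order.size(); ++I)
    RPOIndex[Order[Order.size() - 1 - I]] = static_cast<int>(I);

  for (auto &KV : DomChildren) // only siblings need sorting
    std::sort(KV.second.begin(), KV.second.end(),
              [&](int A, int B) { return RPOIndex[A] < RPOIndex[B]; });
}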
+ auto DFI = df_begin(DT->getRootNode()); + for (auto DFE = df_end(DT->getRootNode()); DFI != DFE; ++DFI) { + BasicBlock *B = DFI->getBlock(); + const auto &BlockRange = assignDFSNumbers(B, ICount); + BlockInstRange.insert({B, BlockRange}); + ICount += BlockRange.second - BlockRange.first; + } + + // Handle forward unreachable blocks and figure out which blocks + // have single preds. + for (auto &B : F) { + // Assign numbers to unreachable blocks. + if (!DFI.nodeVisited(DT->getNode(&B))) { + const auto &BlockRange = assignDFSNumbers(&B, ICount); + BlockInstRange.insert({&B, BlockRange}); + ICount += BlockRange.second - BlockRange.first; + } + } + + TouchedInstructions.resize(ICount); + DominatedInstRange.reserve(F.size()); + // Ensure we don't end up resizing the expressionToClass map, as + // that can be quite expensive. At most, we have one expression per + // instruction. + ExpressionToClass.reserve(ICount); + + // Initialize the touched instructions to include the entry block. + const auto &InstRange = BlockInstRange.lookup(&F.getEntryBlock()); + TouchedInstructions.set(InstRange.first, InstRange.second); + ReachableBlocks.insert(&F.getEntryBlock()); + + initializeCongruenceClasses(F); + + unsigned int Iterations = 0; + // We start out in the entry block. + BasicBlock *LastBlock = &F.getEntryBlock(); + while (TouchedInstructions.any()) { + ++Iterations; + // Walk through all the instructions in all the blocks in RPO. + for (int InstrNum = TouchedInstructions.find_first(); InstrNum != -1; + InstrNum = TouchedInstructions.find_next(InstrNum)) { + assert(InstrNum != 0 && "Bit 0 should never be set, something touched an " + "instruction not in the lookup table"); + Value *V = DFSToInstr[InstrNum]; + BasicBlock *CurrBlock = nullptr; + + if (auto *I = dyn_cast<Instruction>(V)) + CurrBlock = I->getParent(); + else if (auto *MP = dyn_cast<MemoryPhi>(V)) + CurrBlock = MP->getBlock(); + else + llvm_unreachable("DFSToInstr gave us an unknown type of instruction"); + + // If we hit a new block, do reachability processing. + if (CurrBlock != LastBlock) { + LastBlock = CurrBlock; + bool BlockReachable = ReachableBlocks.count(CurrBlock); + const auto &CurrInstRange = BlockInstRange.lookup(CurrBlock); + + // If it's not reachable, erase any touched instructions and move on. + if (!BlockReachable) { + TouchedInstructions.reset(CurrInstRange.first, CurrInstRange.second); + DEBUG(dbgs() << "Skipping instructions in block " + << getBlockName(CurrBlock) + << " because it is unreachable\n"); + continue; + } + updateProcessedCount(CurrBlock); + } + + if (auto *MP = dyn_cast<MemoryPhi>(V)) { + DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n"); + valueNumberMemoryPhi(MP); + } else if (auto *I = dyn_cast<Instruction>(V)) { + valueNumberInstruction(I); + } else { + llvm_unreachable("Should have been a MemoryPhi or Instruction"); + } + updateProcessedCount(V); + // Reset after processing (because we may mark ourselves as touched when + // we propagate equalities). + TouchedInstructions.reset(InstrNum); + } + } + NumGVNMaxIterations = std::max(NumGVNMaxIterations.getValue(), Iterations); +#ifndef NDEBUG + verifyMemoryCongruency(); +#endif + Changed |= eliminateInstructions(F); + + // Delete all instructions marked for deletion. + for (Instruction *ToErase : InstructionsToErase) { + if (!ToErase->use_empty()) + ToErase->replaceAllUsesWith(UndefValue::get(ToErase->getType())); + + ToErase->eraseFromParent(); + } + + // Delete all unreachable blocks. 
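// Illustrative sketch, not part of this patch: the shape of the driver loop
// above, with a plain bit-per-DFS-number vector standing in for
// TouchedInstructions. Processing an entry may set other bits (its users,
// memory users, newly reachable blocks), and its own bit is cleared after it
// has been handled, so the loop runs until a fixpoint is reached.
#include <functional>
#include <vector>

void runToFixpoint(std::vector<bool> &Touched,
                   const std::function<void(int, std::vector<bool> &)> &Process) {
  bool Any = true;
  while (Any) {
    Any = false;
    for (int N = 0, E = static_cast<int>(Touched.size()); N != E; ++N) {
      if (!Touched[N])
        continue;
      Any = true;
      Process(N, Touched); // may re-touch other entries
      Touched[N] = false;  // reset after processing, as the pass does
    }
  }
}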
+ auto UnreachableBlockPred = [&](const BasicBlock &BB) { + return !ReachableBlocks.count(&BB); + }; + + for (auto &BB : make_filter_range(F, UnreachableBlockPred)) { + DEBUG(dbgs() << "We believe block " << getBlockName(&BB) + << " is unreachable\n"); + deleteInstructionsInBlock(&BB); + Changed = true; + } + + cleanupTables(); + return Changed; +} + +bool NewGVN::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + return runGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), + &getAnalysis<AAResultsWrapperPass>().getAAResults(), + &getAnalysis<MemorySSAWrapperPass>().getMSSA()); +} + +PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) { + NewGVN Impl; + + // Apparently the order in which we get these results matter for + // the old GVN (see Chandler's comment in GVN.cpp). I'll keep + // the same order here, just in case. + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); + bool Changed = Impl.runGVN(F, &DT, &AC, &TLI, &AA, &MSSA); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<GlobalsAA>(); + return PA; +} + +// Return true if V is a value that will always be available (IE can +// be placed anywhere) in the function. We don't do globals here +// because they are often worse to put in place. +// TODO: Separate cost from availability +static bool alwaysAvailable(Value *V) { + return isa<Constant>(V) || isa<Argument>(V); +} + +// Get the basic block from an instruction/value. +static BasicBlock *getBlockForValue(Value *V) { + if (auto *I = dyn_cast<Instruction>(V)) + return I->getParent(); + return nullptr; +} + +struct NewGVN::ValueDFS { + int DFSIn = 0; + int DFSOut = 0; + int LocalNum = 0; + // Only one of these will be set. + Value *Val = nullptr; + Use *U = nullptr; + + bool operator<(const ValueDFS &Other) const { + // It's not enough that any given field be less than - we have sets + // of fields that need to be evaluated together to give a proper ordering. + // For example, if you have; + // DFS (1, 3) + // Val 0 + // DFS (1, 2) + // Val 50 + // We want the second to be less than the first, but if we just go field + // by field, we will get to Val 0 < Val 50 and say the first is less than + // the second. We only want it to be less than if the DFS orders are equal. + // + // Each LLVM instruction only produces one value, and thus the lowest-level + // differentiator that really matters for the stack (and what we use as as a + // replacement) is the local dfs number. + // Everything else in the structure is instruction level, and only affects + // the order in which we will replace operands of a given instruction. + // + // For a given instruction (IE things with equal dfsin, dfsout, localnum), + // the order of replacement of uses does not matter. + // IE given, + // a = 5 + // b = a + a + // When you hit b, you will have two valuedfs with the same dfsin, out, and + // localnum. + // The .val will be the same as well. + // The .u's will be different. + // You will replace both, and it does not matter what order you replace them + // in (IE whether you replace operand 2, then operand 1, or operand 1, then + // operand 2). 
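// Illustrative sketch, not part of this patch: the comparator reasoning in the
// comment above, using a simplified struct. std::tie gives exactly the needed
// lexicographic order: entries are grouped by their dominator-tree interval
// first, then by position inside the block, and the trailing value/use fields
// only break the remaining ties, whose relative order does not affect the
// result.
#include <algorithm>
#include <tuple>
#include <vector>

struct ValueDfsSketch {
  int DFSIn = 0;
  int DFSOut = 0;
  int LocalNum = 0;
  const void *Val = nullptr; // set for a definition
  const void *Use = nullptr; // set for a use

  bool operator<(const ValueDfsSketch &O) const {
    return std::tie(DFSIn, DFSOut, LocalNum, Val, Use) <
           std::tie(O.DFSIn, O.DFSOut, O.LocalNum, O.Val, O.Use);
  }
};

inline void sortDfsOrdered(std::vector<ValueDfsSketch> &Set) {
  std::sort(Set.begin(), Set.end());
}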
+ // Similarly for the case of same dfsin, dfsout, localnum, but different + // .val's + // a = 5 + // b = 6 + // c = a + b + // in c, we will a valuedfs for a, and one for b,with everything the same + // but .val and .u. + // It does not matter what order we replace these operands in. + // You will always end up with the same IR, and this is guaranteed. + return std::tie(DFSIn, DFSOut, LocalNum, Val, U) < + std::tie(Other.DFSIn, Other.DFSOut, Other.LocalNum, Other.Val, + Other.U); + } +}; + +void NewGVN::convertDenseToDFSOrdered( + CongruenceClass::MemberSet &Dense, + SmallVectorImpl<ValueDFS> &DFSOrderedSet) { + for (auto D : Dense) { + // First add the value. + BasicBlock *BB = getBlockForValue(D); + // Constants are handled prior to ever calling this function, so + // we should only be left with instructions as members. + assert(BB && "Should have figured out a basic block for value"); + ValueDFS VD; + + std::pair<int, int> DFSPair = DFSDomMap[BB]; + assert(DFSPair.first != -1 && DFSPair.second != -1 && "Invalid DFS Pair"); + VD.DFSIn = DFSPair.first; + VD.DFSOut = DFSPair.second; + VD.Val = D; + // If it's an instruction, use the real local dfs number. + if (auto *I = dyn_cast<Instruction>(D)) + VD.LocalNum = InstrDFS[I]; + else + llvm_unreachable("Should have been an instruction"); + + DFSOrderedSet.emplace_back(VD); + + // Now add the users. + for (auto &U : D->uses()) { + if (auto *I = dyn_cast<Instruction>(U.getUser())) { + ValueDFS VD; + // Put the phi node uses in the incoming block. + BasicBlock *IBlock; + if (auto *P = dyn_cast<PHINode>(I)) { + IBlock = P->getIncomingBlock(U); + // Make phi node users appear last in the incoming block + // they are from. + VD.LocalNum = InstrDFS.size() + 1; + } else { + IBlock = I->getParent(); + VD.LocalNum = InstrDFS[I]; + } + std::pair<int, int> DFSPair = DFSDomMap[IBlock]; + VD.DFSIn = DFSPair.first; + VD.DFSOut = DFSPair.second; + VD.U = &U; + DFSOrderedSet.emplace_back(VD); + } + } + } +} + +static void patchReplacementInstruction(Instruction *I, Value *Repl) { + // Patch the replacement so that it is not more restrictive than the value + // being replaced. + auto *Op = dyn_cast<BinaryOperator>(I); + auto *ReplOp = dyn_cast<BinaryOperator>(Repl); + + if (Op && ReplOp) + ReplOp->andIRFlags(Op); + + if (auto *ReplInst = dyn_cast<Instruction>(Repl)) { + // FIXME: If both the original and replacement value are part of the + // same control-flow region (meaning that the execution of one + // guarentees the executation of the other), then we can combine the + // noalias scopes here and do better than the general conservative + // answer used in combineMetadata(). + + // In general, GVN unifies expressions over different control-flow + // regions, and so we need a conservative combination of the noalias + // scopes. + unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_range, + LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, + LLVMContext::MD_invariant_group}; + combineMetadata(ReplInst, I, KnownIDs); + } +} + +static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { + patchReplacementInstruction(I, Repl); + I->replaceAllUsesWith(Repl); +} + +void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) { + DEBUG(dbgs() << " BasicBlock Dead:" << *BB); + ++NumGVNBlocksDeleted; + + // Check to see if there are non-terminating instructions to delete. 
+ if (isa<TerminatorInst>(BB->begin())) + return; + + // Delete the instructions backwards, as it has a reduced likelihood of having + // to update as many def-use and use-def chains. Start after the terminator. + auto StartPoint = BB->rbegin(); + ++StartPoint; + // Note that we explicitly recalculate BB->rend() on each iteration, + // as it may change when we remove the first instruction. + for (BasicBlock::reverse_iterator I(StartPoint); I != BB->rend();) { + Instruction &Inst = *I++; + if (!Inst.use_empty()) + Inst.replaceAllUsesWith(UndefValue::get(Inst.getType())); + if (isa<LandingPadInst>(Inst)) + continue; + + Inst.eraseFromParent(); + ++NumGVNInstrDeleted; + } +} + +void NewGVN::markInstructionForDeletion(Instruction *I) { + DEBUG(dbgs() << "Marking " << *I << " for deletion\n"); + InstructionsToErase.insert(I); +} + +void NewGVN::replaceInstruction(Instruction *I, Value *V) { + + DEBUG(dbgs() << "Replacing " << *I << " with " << *V << "\n"); + patchAndReplaceAllUsesWith(I, V); + // We save the actual erasing to avoid invalidating memory + // dependencies until we are done with everything. + markInstructionForDeletion(I); +} + +namespace { + +// This is a stack that contains both the value and dfs info of where +// that value is valid. +class ValueDFSStack { +public: + Value *back() const { return ValueStack.back(); } + std::pair<int, int> dfs_back() const { return DFSStack.back(); } + + void push_back(Value *V, int DFSIn, int DFSOut) { + ValueStack.emplace_back(V); + DFSStack.emplace_back(DFSIn, DFSOut); + } + bool empty() const { return DFSStack.empty(); } + bool isInScope(int DFSIn, int DFSOut) const { + if (empty()) + return false; + return DFSIn >= DFSStack.back().first && DFSOut <= DFSStack.back().second; + } + + void popUntilDFSScope(int DFSIn, int DFSOut) { + + // These two should always be in sync at this point. + assert(ValueStack.size() == DFSStack.size() && + "Mismatch between ValueStack and DFSStack"); + while ( + !DFSStack.empty() && + !(DFSIn >= DFSStack.back().first && DFSOut <= DFSStack.back().second)) { + DFSStack.pop_back(); + ValueStack.pop_back(); + } + } + +private: + SmallVector<Value *, 8> ValueStack; + SmallVector<std::pair<int, int>, 8> DFSStack; +}; +} + +bool NewGVN::eliminateInstructions(Function &F) { + // This is a non-standard eliminator. The normal way to eliminate is + // to walk the dominator tree in order, keeping track of available + // values, and eliminating them. However, this is mildly + // pointless. It requires doing lookups on every instruction, + // regardless of whether we will ever eliminate it. For + // instructions part of most singleton congruence classes, we know we + // will never eliminate them. + + // Instead, this eliminator looks at the congruence classes directly, sorts + // them into a DFS ordering of the dominator tree, and then we just + // perform elimination straight on the sets by walking the congruence + // class member uses in order, and eliminate the ones dominated by the + // last member. This is worst case O(E log E) where E = number of + // instructions in a single congruence class. In theory, this is all + // instructions. In practice, it is much faster, as most instructions are + // either in singleton congruence classes or can't possibly be eliminated + // anyway (if there are no overlapping DFS ranges in class). + // When we find something not dominated, it becomes the new leader + // for elimination purposes. 
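// Illustrative sketch, not part of this patch: the dominator-scope stack the
// eliminator below relies on, reduced to integers. A candidate replacement
// stays usable while the current member's (DFSIn, DFSOut) interval is nested
// inside the interval recorded for the stack top; leaving that interval pops
// the stack.
#include <utility>
#include <vector>

class DfsScopeStackSketch {
public:
  void push(int Value, int DFSIn, int DFSOut) {
    Stack.emplace_back(Value, std::make_pair(DFSIn, DFSOut));
  }
  bool empty() const { return Stack.empty(); }
  int top() const { return Stack.back().first; }

  bool isInScope(int DFSIn, int DFSOut) const {
    return !Stack.empty() && DFSIn >= Stack.back().second.first &&
           DFSOut <= Stack.back().second.second;
  }

  // Pop entries whose interval no longer contains (DFSIn, DFSOut).
  void popUntilScope(int DFSIn, int DFSOut) {
    while (!Stack.empty() && !isInScope(DFSIn, DFSOut))
      Stack.pop_back();
  }

private:
  std::vector<std::pair<int, std::pair<int, int>>> Stack;
};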
+ // TODO: If we wanted to be faster, We could remove any members with no + // overlapping ranges while sorting, as we will never eliminate anything + // with those members, as they don't dominate anything else in our set. + + bool AnythingReplaced = false; + + // Since we are going to walk the domtree anyway, and we can't guarantee the + // DFS numbers are updated, we compute some ourselves. + DT->updateDFSNumbers(); + + for (auto &B : F) { + if (!ReachableBlocks.count(&B)) { + for (const auto S : successors(&B)) { + for (auto II = S->begin(); isa<PHINode>(II); ++II) { + auto &Phi = cast<PHINode>(*II); + DEBUG(dbgs() << "Replacing incoming value of " << *II << " for block " + << getBlockName(&B) + << " with undef due to it being unreachable\n"); + for (auto &Operand : Phi.incoming_values()) + if (Phi.getIncomingBlock(Operand) == &B) + Operand.set(UndefValue::get(Phi.getType())); + } + } + } + DomTreeNode *Node = DT->getNode(&B); + if (Node) + DFSDomMap[&B] = {Node->getDFSNumIn(), Node->getDFSNumOut()}; + } + + for (CongruenceClass *CC : CongruenceClasses) { + // FIXME: We should eventually be able to replace everything still + // in the initial class with undef, as they should be unreachable. + // Right now, initial still contains some things we skip value + // numbering of (UNREACHABLE's, for example). + if (CC == InitialClass || CC->Dead) + continue; + assert(CC->RepLeader && "We should have had a leader"); + + // If this is a leader that is always available, and it's a + // constant or has no equivalences, just replace everything with + // it. We then update the congruence class with whatever members + // are left. + if (alwaysAvailable(CC->RepLeader)) { + SmallPtrSet<Value *, 4> MembersLeft; + for (auto M : CC->Members) { + + Value *Member = M; + + // Void things have no uses we can replace. + if (Member == CC->RepLeader || Member->getType()->isVoidTy()) { + MembersLeft.insert(Member); + continue; + } + + DEBUG(dbgs() << "Found replacement " << *(CC->RepLeader) << " for " + << *Member << "\n"); + // Due to equality propagation, these may not always be + // instructions, they may be real values. We don't really + // care about trying to replace the non-instructions. + if (auto *I = dyn_cast<Instruction>(Member)) { + assert(CC->RepLeader != I && + "About to accidentally remove our leader"); + replaceInstruction(I, CC->RepLeader); + AnythingReplaced = true; + + continue; + } else { + MembersLeft.insert(I); + } + } + CC->Members.swap(MembersLeft); + + } else { + DEBUG(dbgs() << "Eliminating in congruence class " << CC->ID << "\n"); + // If this is a singleton, we can skip it. + if (CC->Members.size() != 1) { + + // This is a stack because equality replacement/etc may place + // constants in the middle of the member list, and we want to use + // those constant values in preference to the current leader, over + // the scope of those constants. + ValueDFSStack EliminationStack; + + // Convert the members to DFS ordered sets and then merge them. + SmallVector<ValueDFS, 8> DFSOrderedSet; + convertDenseToDFSOrdered(CC->Members, DFSOrderedSet); + + // Sort the whole thing. + std::sort(DFSOrderedSet.begin(), DFSOrderedSet.end()); + + for (auto &VD : DFSOrderedSet) { + int MemberDFSIn = VD.DFSIn; + int MemberDFSOut = VD.DFSOut; + Value *Member = VD.Val; + Use *MemberUse = VD.U; + + if (Member) { + // We ignore void things because we can't get a value from them. + // FIXME: We could actually use this to kill dead stores that are + // dominated by equivalent earlier stores. 
+ if (Member->getType()->isVoidTy()) + continue; + } + + if (EliminationStack.empty()) { + DEBUG(dbgs() << "Elimination Stack is empty\n"); + } else { + DEBUG(dbgs() << "Elimination Stack Top DFS numbers are (" + << EliminationStack.dfs_back().first << "," + << EliminationStack.dfs_back().second << ")\n"); + } + + DEBUG(dbgs() << "Current DFS numbers are (" << MemberDFSIn << "," + << MemberDFSOut << ")\n"); + // First, we see if we are out of scope or empty. If so, + // and there equivalences, we try to replace the top of + // stack with equivalences (if it's on the stack, it must + // not have been eliminated yet). + // Then we synchronize to our current scope, by + // popping until we are back within a DFS scope that + // dominates the current member. + // Then, what happens depends on a few factors + // If the stack is now empty, we need to push + // If we have a constant or a local equivalence we want to + // start using, we also push. + // Otherwise, we walk along, processing members who are + // dominated by this scope, and eliminate them. + bool ShouldPush = + Member && (EliminationStack.empty() || isa<Constant>(Member)); + bool OutOfScope = + !EliminationStack.isInScope(MemberDFSIn, MemberDFSOut); + + if (OutOfScope || ShouldPush) { + // Sync to our current scope. + EliminationStack.popUntilDFSScope(MemberDFSIn, MemberDFSOut); + ShouldPush |= Member && EliminationStack.empty(); + if (ShouldPush) { + EliminationStack.push_back(Member, MemberDFSIn, MemberDFSOut); + } + } + + // If we get to this point, and the stack is empty we must have a use + // with nothing we can use to eliminate it, just skip it. + if (EliminationStack.empty()) + continue; + + // Skip the Value's, we only want to eliminate on their uses. + if (Member) + continue; + Value *Result = EliminationStack.back(); + + // Don't replace our existing users with ourselves. + if (MemberUse->get() == Result) + continue; + + DEBUG(dbgs() << "Found replacement " << *Result << " for " + << *MemberUse->get() << " in " << *(MemberUse->getUser()) + << "\n"); + + // If we replaced something in an instruction, handle the patching of + // metadata. + if (auto *ReplacedInst = dyn_cast<Instruction>(MemberUse->get())) + patchReplacementInstruction(ReplacedInst, Result); + + assert(isa<Instruction>(MemberUse->getUser())); + MemberUse->set(Result); + AnythingReplaced = true; + } + } + } + + // Cleanup the congruence class. + SmallPtrSet<Value *, 4> MembersLeft; + for (Value *Member : CC->Members) { + if (Member->getType()->isVoidTy()) { + MembersLeft.insert(Member); + continue; + } + + if (auto *MemberInst = dyn_cast<Instruction>(Member)) { + if (isInstructionTriviallyDead(MemberInst)) { + // TODO: Don't mark loads of undefs. 
+ markInstructionForDeletion(MemberInst); + continue; + } + } + MembersLeft.insert(Member); + } + CC->Members.swap(MembersLeft); + } + + return AnythingReplaced; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index c4b3e34..1a7ddc9 100644 --- a/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -123,7 +123,7 @@ static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI, } PreservedAnalyses -PartiallyInlineLibCallsPass::run(Function &F, AnalysisManager<Function> &AM) { +PartiallyInlineLibCallsPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); if (!runPartiallyInlineLibCalls(F, &TLI, &TTI)) diff --git a/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp b/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp index e42e2c6..65c814d 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -145,7 +145,8 @@ static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode1, return nullptr; } -void ReassociatePass::BuildRankMap(Function &F) { +void ReassociatePass::BuildRankMap(Function &F, + ReversePostOrderTraversal<Function*> &RPOT) { unsigned i = 2; // Assign distinct ranks to function arguments. @@ -154,7 +155,7 @@ void ReassociatePass::BuildRankMap(Function &F) { DEBUG(dbgs() << "Calculated Rank[" << I->getName() << "] = " << i << "\n"); } - ReversePostOrderTraversal<Function *> RPOT(&F); + // Traverse basic blocks in ReversePostOrder for (BasicBlock *BB : RPOT) { unsigned BBRank = RankMap[BB] = ++i << 16; @@ -507,9 +508,10 @@ static bool LinearizeExprTree(BinaryOperator *I, continue; } // No uses outside the expression, try morphing it. - } else if (It != Leaves.end()) { + } else { // Already in the leaf map. - assert(Visited.count(Op) && "In leaf map but not visited!"); + assert(It != Leaves.end() && Visited.count(Op) && + "In leaf map but not visited!"); // Update the number of paths to the leaf. IncorporateWeight(It->second, Weight, Opcode); @@ -1519,8 +1521,8 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, if (ConstantInt *CI = dyn_cast<ConstantInt>(Factor)) { if (CI->isNegative() && !CI->isMinValue(true)) { Factor = ConstantInt::get(CI->getContext(), -CI->getValue()); - assert(!Duplicates.count(Factor) && - "Shouldn't have two constant factors, missed a canonicalize"); + if (!Duplicates.insert(Factor).second) + continue; unsigned Occ = ++FactorOccurrences[Factor]; if (Occ > MaxOcc) { MaxOcc = Occ; @@ -1532,8 +1534,8 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, APFloat F(CF->getValueAPF()); F.changeSign(); Factor = ConstantFP::get(CF->getContext(), F); - assert(!Duplicates.count(Factor) && - "Shouldn't have two constant factors, missed a canonicalize"); + if (!Duplicates.insert(Factor).second) + continue; unsigned Occ = ++FactorOccurrences[Factor]; if (Occ > MaxOcc) { MaxOcc = Occ; @@ -1776,6 +1778,12 @@ Value *ReassociatePass::OptimizeMul(BinaryOperator *I, return nullptr; // All distinct factors, so nothing left for us to do. IRBuilder<> Builder(I); + // The reassociate transformation for FP operations is performed only + // if unsafe algebra is permitted by FastMathFlags. Propagate those flags + // to the newly generated operations. 
+ if (auto FPI = dyn_cast<FPMathOperator>(I)) + Builder.setFastMathFlags(FPI->getFastMathFlags()); + Value *V = buildMinimalMultiplyDAG(Builder, Factors); if (Ops.empty()) return V; @@ -1863,6 +1871,8 @@ void ReassociatePass::RecursivelyEraseDeadInsts( /// Zap the given instruction, adding interesting operands to the work list. void ReassociatePass::EraseInst(Instruction *I) { assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); + DEBUG(dbgs() << "Erasing dead inst: "; I->dump()); + SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); // Erase the dead instruction. ValueRankMap.erase(I); @@ -2172,11 +2182,19 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { } PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { + // Get the functions basic blocks in Reverse Post Order. This order is used by + // BuildRankMap to pre calculate ranks correctly. It also excludes dead basic + // blocks (it has been seen that the analysis in this pass could hang when + // analysing dead basic blocks). + ReversePostOrderTraversal<Function *> RPOT(&F); + // Calculate the rank map for F. - BuildRankMap(F); + BuildRankMap(F, RPOT); MadeChange = false; - for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { + // Traverse the same blocks that was analysed by BuildRankMap. + for (BasicBlock *BI : RPOT) { + assert(RankMap.count(&*BI) && "BB should be ranked."); // Optimize every instruction in the basic block. for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE;) if (isInstructionTriviallyDead(&*II)) { @@ -2196,8 +2214,10 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { // trivially dead instructions have been removed. while (!ToRedo.empty()) { Instruction *I = ToRedo.pop_back_val(); - if (isInstructionTriviallyDead(I)) + if (isInstructionTriviallyDead(I)) { RecursivelyEraseDeadInsts(I, ToRedo); + MadeChange = true; + } } // Now that we have removed dead instructions, we can reoptimize the diff --git a/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index bab39a3..1de7420 100644 --- a/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -453,7 +453,7 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { if (isa<CallInst>(I) || isa<InvokeInst>(I)) return BaseDefiningValueResult(I, true); - // I have absolutely no idea how to implement this part yet. It's not + // TODO: I have absolutely no idea how to implement this part yet. It's not // necessarily hard, I just haven't really looked at it yet. assert(!isa<LandingPadInst>(I) && "Landing Pad is unimplemented"); @@ -676,7 +676,8 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { #ifndef NDEBUG auto isExpectedBDVType = [](Value *BDV) { return isa<PHINode>(BDV) || isa<SelectInst>(BDV) || - isa<ExtractElementInst>(BDV) || isa<InsertElementInst>(BDV); + isa<ExtractElementInst>(BDV) || isa<InsertElementInst>(BDV) || + isa<ShuffleVectorInst>(BDV); }; #endif @@ -719,9 +720,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { } else if (auto *IE = dyn_cast<InsertElementInst>(Current)) { visitIncomingValue(IE->getOperand(0)); // vector operand visitIncomingValue(IE->getOperand(1)); // scalar operand - } else { - // There is one known class of instructions we know we don't handle. 
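// Illustrative sketch, not part of this patch, and deliberately simplified:
// the ranking scheme the reverse-post-order traversal above feeds. Arguments
// get small distinct ranks, each block visited in RPO contributes a block rank
// in the high bits, and instructions get a rank derived from their block, so
// values defined in later blocks rank higher. (The real pass computes
// instruction ranks lazily; this only shows the numbering shape.)
#include <map>
#include <string>
#include <utility>
#include <vector>

struct RankMapSketch {
  std::map<std::string, unsigned> Rank;

  void build(const std::vector<std::string> &Args,
             const std::vector<std::pair<std::string, std::vector<std::string>>>
                 &BlocksInRPO) {
    unsigned I = 2;
    for (const auto &A : Args)
      Rank[A] = ++I; // distinct, low ranks for arguments
    for (const auto &BB : BlocksInRPO) {
      unsigned BlockRank = ++I << 16; // block component in the high bits
      unsigned Local = 0;
      for (const auto &Inst : BB.second)
        Rank[Inst] = BlockRank + ++Local; // simplified per-instruction rank
    }
  }
};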
- assert(isa<ShuffleVectorInst>(Current)); + } else if (auto *SV = dyn_cast<ShuffleVectorInst>(Current)) { + visitIncomingValue(SV->getOperand(0)); + visitIncomingValue(SV->getOperand(1)); + } + else { llvm_unreachable("Unimplemented instruction case"); } } @@ -778,12 +781,17 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // useful in that it drives us to conflict if our input is. NewState = meetBDVState(NewState, getStateForInput(EE->getVectorOperand())); - } else { + } else if (auto *IE = dyn_cast<InsertElementInst>(BDV)){ // Given there's a inherent type mismatch between the operands, will // *always* produce Conflict. - auto *IE = cast<InsertElementInst>(BDV); NewState = meetBDVState(NewState, getStateForInput(IE->getOperand(0))); NewState = meetBDVState(NewState, getStateForInput(IE->getOperand(1))); + } else { + // The only instance this does not return a Conflict is when both the + // vector operands are the same vector. + auto *SV = cast<ShuffleVectorInst>(BDV); + NewState = meetBDVState(NewState, getStateForInput(SV->getOperand(0))); + NewState = meetBDVState(NewState, getStateForInput(SV->getOperand(1))); } BDVState OldState = States[BDV]; @@ -855,13 +863,18 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { std::string Name = suffixed_name_or(I, ".base", "base_ee"); return ExtractElementInst::Create(Undef, EE->getIndexOperand(), Name, EE); - } else { - auto *IE = cast<InsertElementInst>(I); + } else if (auto *IE = dyn_cast<InsertElementInst>(I)) { UndefValue *VecUndef = UndefValue::get(IE->getOperand(0)->getType()); UndefValue *ScalarUndef = UndefValue::get(IE->getOperand(1)->getType()); std::string Name = suffixed_name_or(I, ".base", "base_ie"); return InsertElementInst::Create(VecUndef, ScalarUndef, IE->getOperand(2), Name, IE); + } else { + auto *SV = cast<ShuffleVectorInst>(I); + UndefValue *VecUndef = UndefValue::get(SV->getOperand(0)->getType()); + std::string Name = suffixed_name_or(I, ".base", "base_sv"); + return new ShuffleVectorInst(VecUndef, VecUndef, SV->getOperand(2), + Name, SV); } }; Instruction *BaseInst = MakeBaseInstPlaceholder(I); @@ -963,8 +976,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { // Find the instruction which produces the base for each input. We may // need to insert a bitcast. 
BaseEE->setOperand(0, getBaseForInput(InVal, BaseEE)); - } else { - auto *BaseIE = cast<InsertElementInst>(State.getBaseValue()); + } else if (auto *BaseIE = dyn_cast<InsertElementInst>(State.getBaseValue())){ auto *BdvIE = cast<InsertElementInst>(BDV); auto UpdateOperand = [&](int OperandIdx) { Value *InVal = BdvIE->getOperand(OperandIdx); @@ -973,6 +985,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { }; UpdateOperand(0); // vector operand UpdateOperand(1); // scalar operand + } else { + auto *BaseSV = cast<ShuffleVectorInst>(State.getBaseValue()); + auto *BdvSV = cast<ShuffleVectorInst>(BDV); + auto UpdateOperand = [&](int OperandIdx) { + Value *InVal = BdvSV->getOperand(OperandIdx); + Value *Base = getBaseForInput(InVal, BaseSV); + BaseSV->setOperand(OperandIdx, Base); + }; + UpdateOperand(0); // vector operand + UpdateOperand(1); // vector operand } } @@ -1154,7 +1176,7 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, return; auto FindIndex = [](ArrayRef<Value *> LiveVec, Value *Val) { - auto ValIt = std::find(LiveVec.begin(), LiveVec.end(), Val); + auto ValIt = find(LiveVec, Val); assert(ValIt != LiveVec.end() && "Val not found in LiveVec!"); size_t Index = std::distance(LiveVec.begin(), ValIt); assert(Index < LiveVec.size() && "Bug in std::find?"); @@ -1273,6 +1295,24 @@ public: }; } +static StringRef getDeoptLowering(CallSite CS) { + const char *DeoptLowering = "deopt-lowering"; + if (CS.hasFnAttr(DeoptLowering)) { + // FIXME: CallSite has a *really* confusing interface around attributes + // with values. + const AttributeSet &CSAS = CS.getAttributes(); + if (CSAS.hasAttribute(AttributeSet::FunctionIndex, + DeoptLowering)) + return CSAS.getAttribute(AttributeSet::FunctionIndex, + DeoptLowering).getValueAsString(); + Function *F = CS.getCalledFunction(); + assert(F && F->hasFnAttribute(DeoptLowering)); + return F->getFnAttribute(DeoptLowering).getValueAsString(); + } + return "live-through"; +} + + static void makeStatepointExplicitImpl(const CallSite CS, /* to replace */ const SmallVectorImpl<Value *> &BasePtrs, @@ -1314,6 +1354,14 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ if (SD.StatepointID) StatepointID = *SD.StatepointID; + // Pass through the requested lowering if any. The default is live-through. + StringRef DeoptLowering = getDeoptLowering(CS); + if (DeoptLowering.equals("live-in")) + Flags |= uint32_t(StatepointFlags::DeoptLiveIn); + else { + assert(DeoptLowering.equals("live-through") && "Unsupported value!"); + } + Value *CallTarget = CS.getCalledValue(); if (Function *F = dyn_cast<Function>(CallTarget)) { if (F->getIntrinsicID() == Intrinsic::experimental_deoptimize) { @@ -1347,7 +1395,7 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */ StatepointID, NumPatchBytes, CallTarget, Flags, CallArgs, TransitionArgs, DeoptArgs, GCArgs, "safepoint_token"); - Call->setTailCall(ToReplace->isTailCall()); + Call->setTailCallKind(ToReplace->getTailCallKind()); Call->setCallingConv(ToReplace->getCallingConv()); // Currently we will fail on parameter attributes and on certain @@ -1740,9 +1788,8 @@ static void relocationViaAlloca( /// tests in ways which make them less useful in testing fused safepoints. 
template <typename T> static void unique_unsorted(SmallVectorImpl<T> &Vec) { SmallSet<T, 8> Seen; - Vec.erase(std::remove_if(Vec.begin(), Vec.end(), [&](const T &V) { - return !Seen.insert(V).second; - }), Vec.end()); + Vec.erase(remove_if(Vec, [&](const T &V) { return !Seen.insert(V).second; }), + Vec.end()); } /// Insert holders so that each Value is obviously live through the entire @@ -1784,38 +1831,33 @@ static void findLiveReferences( } // Helper function for the "rematerializeLiveValues". It walks use chain -// starting from the "CurrentValue" until it meets "BaseValue". Only "simple" -// values are visited (currently it is GEP's and casts). Returns true if it -// successfully reached "BaseValue" and false otherwise. -// Fills "ChainToBase" array with all visited values. "BaseValue" is not -// recorded. -static bool findRematerializableChainToBasePointer( +// starting from the "CurrentValue" until it reaches the root of the chain, i.e. +// the base or a value it cannot process. Only "simple" values are processed +// (currently it is GEP's and casts). The returned root is examined by the +// callers of findRematerializableChainToBasePointer. Fills "ChainToBase" array +// with all visited values. +static Value* findRematerializableChainToBasePointer( SmallVectorImpl<Instruction*> &ChainToBase, - Value *CurrentValue, Value *BaseValue) { - - // We have found a base value - if (CurrentValue == BaseValue) { - return true; - } + Value *CurrentValue) { if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(CurrentValue)) { ChainToBase.push_back(GEP); return findRematerializableChainToBasePointer(ChainToBase, - GEP->getPointerOperand(), - BaseValue); + GEP->getPointerOperand()); } if (CastInst *CI = dyn_cast<CastInst>(CurrentValue)) { if (!CI->isNoopCast(CI->getModule()->getDataLayout())) - return false; + return CI; ChainToBase.push_back(CI); return findRematerializableChainToBasePointer(ChainToBase, - CI->getOperand(0), BaseValue); + CI->getOperand(0)); } - // Not supported instruction in the chain - return false; + // We have reached the root of the chain, which is either equal to the base or + // is the first unsupported value along the use chain. + return CurrentValue; } // Helper function for the "rematerializeLiveValues". Compute cost of the use @@ -1852,6 +1894,34 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain, return Cost; } +static bool AreEquivalentPhiNodes(PHINode &OrigRootPhi, PHINode &AlternateRootPhi) { + + unsigned PhiNum = OrigRootPhi.getNumIncomingValues(); + if (PhiNum != AlternateRootPhi.getNumIncomingValues() || + OrigRootPhi.getParent() != AlternateRootPhi.getParent()) + return false; + // Map of incoming values and their corresponding basic blocks of + // OrigRootPhi. + SmallDenseMap<Value *, BasicBlock *, 8> CurrentIncomingValues; + for (unsigned i = 0; i < PhiNum; i++) + CurrentIncomingValues[OrigRootPhi.getIncomingValue(i)] = + OrigRootPhi.getIncomingBlock(i); + + // Both current and base PHIs should have same incoming values and + // the same basic blocks corresponding to the incoming values. + for (unsigned i = 0; i < PhiNum; i++) { + auto CIVI = + CurrentIncomingValues.find(AlternateRootPhi.getIncomingValue(i)); + if (CIVI == CurrentIncomingValues.end()) + return false; + BasicBlock *CurrentIncomingBB = CIVI->second; + if (CurrentIncomingBB != AlternateRootPhi.getIncomingBlock(i)) + return false; + } + return true; + +} + // From the statepoint live set pick values that are cheaper to recompute then // to relocate. 
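// Illustrative sketch, not part of this patch: the phi-equivalence test added
// above, with integer ids in place of Values and BasicBlocks. Two phis are
// treated as the same SSA value when they live in the same block and map each
// incoming value to the same incoming block.
#include <map>
#include <utility>
#include <vector>

struct PhiSketch {
  int Parent = 0;                            // basic block id
  std::vector<std::pair<int, int>> Incoming; // (incoming value, incoming block)
};

bool areEquivalentPhis(const PhiSketch &A, const PhiSketch &B) {
  if (A.Parent != B.Parent || A.Incoming.size() != B.Incoming.size())
    return false;
  std::map<int, int> ValueToBlock;
  for (const auto &In : A.Incoming)
    ValueToBlock[In.first] = In.second;
  for (const auto &In : B.Incoming) {
    auto It = ValueToBlock.find(In.first);
    if (It == ValueToBlock.end() || It->second != In.second)
      return false;
  }
  return true;
}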
Remove this values from the live set, rematerialize them after // statepoint and record them in "Info" structure. Note that similar to @@ -1869,16 +1939,38 @@ static void rematerializeLiveValues(CallSite CS, // For each live pointer find it's defining chain SmallVector<Instruction *, 3> ChainToBase; assert(Info.PointerToBase.count(LiveValue)); - bool FoundChain = + Value *RootOfChain = findRematerializableChainToBasePointer(ChainToBase, - LiveValue, - Info.PointerToBase[LiveValue]); + LiveValue); + // Nothing to do, or chain is too long - if (!FoundChain || - ChainToBase.size() == 0 || + if ( ChainToBase.size() == 0 || ChainToBase.size() > ChainLengthThreshold) continue; + // Handle the scenario where the RootOfChain is not equal to the + // Base Value, but they are essentially the same phi values. + if (RootOfChain != Info.PointerToBase[LiveValue]) { + PHINode *OrigRootPhi = dyn_cast<PHINode>(RootOfChain); + PHINode *AlternateRootPhi = dyn_cast<PHINode>(Info.PointerToBase[LiveValue]); + if (!OrigRootPhi || !AlternateRootPhi) + continue; + // PHI nodes that have the same incoming values, and belonging to the same + // basic blocks are essentially the same SSA value. When the original phi + // has incoming values with different base pointers, the original phi is + // marked as conflict, and an additional `AlternateRootPhi` with the same + // incoming values get generated by the findBasePointer function. We need + // to identify the newly generated AlternateRootPhi (.base version of phi) + // and RootOfChain (the original phi node itself) are the same, so that we + // can rematerialize the gep and casts. This is a workaround for the + // deficieny in the findBasePointer algorithm. + if (!AreEquivalentPhiNodes(*OrigRootPhi, *AlternateRootPhi)) + continue; + // Now that the phi nodes are proved to be the same, assert that + // findBasePointer's newly generated AlternateRootPhi is present in the + // liveset of the call. + assert(Info.LiveSet.count(AlternateRootPhi)); + } // Compute cost of this chain unsigned Cost = chainToBasePointerCost(ChainToBase, TTI); // TODO: We can also account for cases when we will be able to remove some @@ -1906,7 +1998,8 @@ static void rematerializeLiveValues(CallSite CS, // Utility function which clones all instructions from "ChainToBase" // and inserts them before "InsertBefore". Returns rematerialized value // which should be used after statepoint. - auto rematerializeChain = [&ChainToBase](Instruction *InsertBefore) { + auto rematerializeChain = [&ChainToBase]( + Instruction *InsertBefore, Value *RootOfChain, Value *AlternateLiveBase) { Instruction *LastClonedValue = nullptr; Instruction *LastValue = nullptr; for (Instruction *Instr: ChainToBase) { @@ -1926,14 +2019,24 @@ static void rematerializeLiveValues(CallSite CS, assert(LastValue); ClonedValue->replaceUsesOfWith(LastValue, LastClonedValue); #ifndef NDEBUG - // Assert that cloned instruction does not use any instructions from - // this chain other than LastClonedValue for (auto OpValue : ClonedValue->operand_values()) { - assert(std::find(ChainToBase.begin(), ChainToBase.end(), OpValue) == - ChainToBase.end() && + // Assert that cloned instruction does not use any instructions from + // this chain other than LastClonedValue + assert(!is_contained(ChainToBase, OpValue) && "incorrect use in rematerialization chain"); + // Assert that the cloned instruction does not use the RootOfChain + // or the AlternateLiveBase. 
+ assert(OpValue != RootOfChain && OpValue != AlternateLiveBase); } #endif + } else { + // For the first instruction, replace the use of unrelocated base i.e. + // RootOfChain/OrigRootPhi, with the corresponding PHI present in the + // live set. They have been proved to be the same PHI nodes. Note + // that the *only* use of the RootOfChain in the ChainToBase list is + // the first Value in the list. + if (RootOfChain != AlternateLiveBase) + ClonedValue->replaceUsesOfWith(RootOfChain, AlternateLiveBase); } LastClonedValue = ClonedValue; @@ -1948,7 +2051,8 @@ static void rematerializeLiveValues(CallSite CS, if (CS.isCall()) { Instruction *InsertBefore = CS.getInstruction()->getNextNode(); assert(InsertBefore); - Instruction *RematerializedValue = rematerializeChain(InsertBefore); + Instruction *RematerializedValue = rematerializeChain( + InsertBefore, RootOfChain, Info.PointerToBase[LiveValue]); Info.RematerializedValues[RematerializedValue] = LiveValue; } else { InvokeInst *Invoke = cast<InvokeInst>(CS.getInstruction()); @@ -1958,10 +2062,10 @@ static void rematerializeLiveValues(CallSite CS, Instruction *UnwindInsertBefore = &*Invoke->getUnwindDest()->getFirstInsertionPt(); - Instruction *NormalRematerializedValue = - rematerializeChain(NormalInsertBefore); - Instruction *UnwindRematerializedValue = - rematerializeChain(UnwindInsertBefore); + Instruction *NormalRematerializedValue = rematerializeChain( + NormalInsertBefore, RootOfChain, Info.PointerToBase[LiveValue]); + Instruction *UnwindRematerializedValue = rematerializeChain( + UnwindInsertBefore, RootOfChain, Info.PointerToBase[LiveValue]); Info.RematerializedValues[NormalRematerializedValue] = LiveValue; Info.RematerializedValues[UnwindRematerializedValue] = LiveValue; @@ -2268,8 +2372,7 @@ static bool shouldRewriteStatepointsIn(Function &F) { void RewriteStatepointsForGC::stripNonValidAttributes(Module &M) { #ifndef NDEBUG - assert(std::any_of(M.begin(), M.end(), shouldRewriteStatepointsIn) && - "precondition!"); + assert(any_of(M, shouldRewriteStatepointsIn) && "precondition!"); #endif for (Function &F : M) @@ -2546,8 +2649,8 @@ static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data, // call result is not live (normal), nor are it's arguments // (unless they're used again later). This adjustment is // specifically what we need to relocate - BasicBlock::reverse_iterator rend(Inst->getIterator()); - computeLiveInValues(BB->rbegin(), rend, LiveOut); + computeLiveInValues(BB->rbegin(), ++Inst->getIterator().getReverse(), + LiveOut); LiveOut.remove(Inst); Out.insert(LiveOut.begin(), LiveOut.end()); } diff --git a/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp b/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp index f74f28a..ede381c 100644 --- a/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -242,7 +242,7 @@ public: /// this method must be called. void AddTrackedFunction(Function *F) { // Add an entry, F -> undef. 
- if (StructType *STy = dyn_cast<StructType>(F->getReturnType())) { + if (auto *STy = dyn_cast<StructType>(F->getReturnType())) { MRVFunctionsTracked.insert(F); for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) TrackedMultipleRetVals.insert(std::make_pair(std::make_pair(F, i), @@ -272,7 +272,7 @@ public: std::vector<LatticeVal> getStructLatticeValueFor(Value *V) const { std::vector<LatticeVal> StructValues; - StructType *STy = dyn_cast<StructType>(V->getType()); + auto *STy = dyn_cast<StructType>(V->getType()); assert(STy && "getStructLatticeValueFor() can be called only on structs"); for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { auto I = StructValueState.find(std::make_pair(V, i)); @@ -300,23 +300,44 @@ public: return TrackedGlobals; } + /// getMRVFunctionsTracked - Get the set of functions which return multiple + /// values tracked by the pass. + const SmallPtrSet<Function *, 16> getMRVFunctionsTracked() { + return MRVFunctionsTracked; + } + void markOverdefined(Value *V) { - assert(!V->getType()->isStructTy() && "Should use other method"); + assert(!V->getType()->isStructTy() && + "structs should use markAnythingOverdefined"); markOverdefined(ValueState[V], V); } /// markAnythingOverdefined - Mark the specified value overdefined. This /// works with both scalars and structs. void markAnythingOverdefined(Value *V) { - if (StructType *STy = dyn_cast<StructType>(V->getType())) + if (auto *STy = dyn_cast<StructType>(V->getType())) for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) markOverdefined(getStructValueState(V, i), V); else markOverdefined(V); } + // isStructLatticeConstant - Return true if all the lattice values + // corresponding to elements of the structure are not overdefined, + // false otherwise. + bool isStructLatticeConstant(Function *F, StructType *STy) { + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + const auto &It = TrackedMultipleRetVals.find(std::make_pair(F, i)); + assert(It != TrackedMultipleRetVals.end()); + LatticeVal LV = It->second; + if (LV.isOverdefined()) + return false; + } + return true; + } + private: - // pushToWorkList - Helper for markConstant/markForcedConstant + // pushToWorkList - Helper for markConstant/markForcedConstant/markOverdefined void pushToWorkList(LatticeVal &IV, Value *V) { if (IV.isOverdefined()) return OverdefinedInstWorkList.push_back(V); @@ -334,12 +355,12 @@ private: } void markConstant(Value *V, Constant *C) { - assert(!V->getType()->isStructTy() && "Should use other method"); + assert(!V->getType()->isStructTy() && "structs should use mergeInValue"); markConstant(ValueState[V], V, C); } void markForcedConstant(Value *V, Constant *C) { - assert(!V->getType()->isStructTy() && "Should use other method"); + assert(!V->getType()->isStructTy() && "structs should use mergeInValue"); LatticeVal &IV = ValueState[V]; IV.markForcedConstant(C); DEBUG(dbgs() << "markForcedConstant: " << *C << ": " << *V << '\n'); @@ -354,12 +375,12 @@ private: if (!IV.markOverdefined()) return; DEBUG(dbgs() << "markOverdefined: "; - if (Function *F = dyn_cast<Function>(V)) + if (auto *F = dyn_cast<Function>(V)) dbgs() << "Function '" << F->getName() << "'\n"; else dbgs() << *V << '\n'); // Only instructions go on the work list - OverdefinedInstWorkList.push_back(V); + pushToWorkList(IV, V); } void mergeInValue(LatticeVal &IV, Value *V, LatticeVal MergeWithV) { @@ -374,7 +395,8 @@ private: } void mergeInValue(Value *V, LatticeVal MergeWithV) { - assert(!V->getType()->isStructTy() && "Should use other 
method"); + assert(!V->getType()->isStructTy() && + "non-structs should use markConstant"); mergeInValue(ValueState[V], V, MergeWithV); } @@ -392,7 +414,7 @@ private: if (!I.second) return LV; // Common case, already in the map. - if (Constant *C = dyn_cast<Constant>(V)) { + if (auto *C = dyn_cast<Constant>(V)) { // Undef values remain unknown. if (!isa<UndefValue>(V)) LV.markConstant(C); // Constants are constant @@ -418,7 +440,7 @@ private: if (!I.second) return LV; // Common case, already in the map. - if (Constant *C = dyn_cast<Constant>(V)) { + if (auto *C = dyn_cast<Constant>(V)) { Constant *Elt = C->getAggregateElement(i); if (!Elt) @@ -489,9 +511,6 @@ private: void visitSelectInst(SelectInst &I); void visitBinaryOperator(Instruction &I); void visitCmpInst(CmpInst &I); - void visitExtractElementInst(ExtractElementInst &I); - void visitInsertElementInst(InsertElementInst &I); - void visitShuffleVectorInst(ShuffleVectorInst &I); void visitExtractValueInst(ExtractValueInst &EVI); void visitInsertValueInst(InsertValueInst &IVI); void visitLandingPadInst(LandingPadInst &I) { markAnythingOverdefined(&I); } @@ -527,7 +546,7 @@ private: void visitInstruction(Instruction &I) { // If a new instruction is added to LLVM that we don't handle. - dbgs() << "SCCP: Don't know how to handle: " << I << '\n'; + DEBUG(dbgs() << "SCCP: Don't know how to handle: " << I << '\n'); markAnythingOverdefined(&I); // Just in case } }; @@ -541,7 +560,7 @@ private: void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, SmallVectorImpl<bool> &Succs) { Succs.resize(TI.getNumSuccessors()); - if (BranchInst *BI = dyn_cast<BranchInst>(&TI)) { + if (auto *BI = dyn_cast<BranchInst>(&TI)) { if (BI->isUnconditional()) { Succs[0] = true; return; @@ -568,7 +587,7 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, return; } - if (SwitchInst *SI = dyn_cast<SwitchInst>(&TI)) { + if (auto *SI = dyn_cast<SwitchInst>(&TI)) { if (!SI->getNumCases()) { Succs[0] = true; return; @@ -594,9 +613,7 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, return; } -#ifndef NDEBUG - dbgs() << "Unknown terminator instruction: " << TI << '\n'; -#endif + DEBUG(dbgs() << "Unknown terminator instruction: " << TI << '\n'); llvm_unreachable("SCCP: Don't know how to handle this terminator!"); } @@ -612,7 +629,7 @@ bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { // Check to make sure this edge itself is actually feasible now. TerminatorInst *TI = From->getTerminator(); - if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + if (auto *BI = dyn_cast<BranchInst>(TI)) { if (BI->isUnconditional()) return true; @@ -632,7 +649,7 @@ bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { if (TI->isExceptional()) return true; - if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { + if (auto *SI = dyn_cast<SwitchInst>(TI)) { if (SI->getNumCases() < 1) return true; @@ -650,9 +667,7 @@ bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { if (isa<IndirectBrInst>(TI)) return true; -#ifndef NDEBUG - dbgs() << "Unknown terminator instruction: " << *TI << '\n'; -#endif + DEBUG(dbgs() << "Unknown terminator instruction: " << *TI << '\n'); llvm_unreachable("SCCP: Don't know how to handle this terminator!"); } @@ -747,7 +762,7 @@ void SCCPSolver::visitReturnInst(ReturnInst &I) { // Handle functions that return multiple values. 
if (!TrackedMultipleRetVals.empty()) { - if (StructType *STy = dyn_cast<StructType>(ResultOp->getType())) + if (auto *STy = dyn_cast<StructType>(ResultOp->getType())) if (MRVFunctionsTracked.count(F)) for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) mergeInValue(TrackedMultipleRetVals[std::make_pair(F, i)], F, @@ -806,7 +821,7 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { } void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) { - StructType *STy = dyn_cast<StructType>(IVI.getType()); + auto *STy = dyn_cast<StructType>(IVI.getType()); if (!STy) return markOverdefined(&IVI); @@ -898,7 +913,8 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { // If this is an AND or OR with 0 or -1, it doesn't matter that the other // operand is overdefined. - if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Or) { + if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Mul || + I.getOpcode() == Instruction::Or) { LatticeVal *NonOverdefVal = nullptr; if (!V1State.isOverdefined()) NonOverdefVal = &V1State; @@ -906,25 +922,19 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { NonOverdefVal = &V2State; if (NonOverdefVal) { - if (NonOverdefVal->isUnknown()) { - // Could annihilate value. - if (I.getOpcode() == Instruction::And) - markConstant(IV, &I, Constant::getNullValue(I.getType())); - else if (VectorType *PT = dyn_cast<VectorType>(I.getType())) - markConstant(IV, &I, Constant::getAllOnesValue(PT)); - else - markConstant(IV, &I, - Constant::getAllOnesValue(I.getType())); + if (NonOverdefVal->isUnknown()) return; - } - if (I.getOpcode() == Instruction::And) { + if (I.getOpcode() == Instruction::And || + I.getOpcode() == Instruction::Mul) { // X and 0 = 0 + // X * 0 = 0 if (NonOverdefVal->getConstant()->isNullValue()) return markConstant(IV, &I, NonOverdefVal->getConstant()); } else { + // X or -1 = -1 if (ConstantInt *CI = NonOverdefVal->getConstantInt()) - if (CI->isAllOnesValue()) // X or -1 = -1 + if (CI->isAllOnesValue()) return markConstant(IV, &I, NonOverdefVal->getConstant()); } } @@ -957,21 +967,6 @@ void SCCPSolver::visitCmpInst(CmpInst &I) { markOverdefined(&I); } -void SCCPSolver::visitExtractElementInst(ExtractElementInst &I) { - // TODO : SCCP does not handle vectors properly. - return markOverdefined(&I); -} - -void SCCPSolver::visitInsertElementInst(InsertElementInst &I) { - // TODO : SCCP does not handle vectors properly. - return markOverdefined(&I); -} - -void SCCPSolver::visitShuffleVectorInst(ShuffleVectorInst &I) { - // TODO : SCCP does not handle vectors properly. - return markOverdefined(&I); -} - // Handle getelementptr instructions. If all operands are constants then we // can turn this into a getelementptr ConstantExpr. // @@ -1044,7 +1039,7 @@ void SCCPSolver::visitLoadInst(LoadInst &I) { return; // Transform load (constant global) into the value loaded. - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) { + if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) { if (!TrackedGlobals.empty()) { // If we are tracking this global, merge in the known value for it. 
DenseMap<GlobalVariable*, LatticeVal>::iterator It = @@ -1132,7 +1127,7 @@ CallOverdefined: continue; } - if (StructType *STy = dyn_cast<StructType>(AI->getType())) { + if (auto *STy = dyn_cast<StructType>(AI->getType())) { for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { LatticeVal CallArg = getStructValueState(*CAI, i); mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg); @@ -1144,7 +1139,7 @@ CallOverdefined: } // If this is a single/zero retval case, see if we're tracking the function. - if (StructType *STy = dyn_cast<StructType>(F->getReturnType())) { + if (auto *STy = dyn_cast<StructType>(F->getReturnType())) { if (!MRVFunctionsTracked.count(F)) goto CallOverdefined; // Not tracking this callee. @@ -1182,7 +1177,7 @@ void SCCPSolver::Solve() { // Update all of the users of this instruction's value. // for (User *U : I->users()) - if (Instruction *UI = dyn_cast<Instruction>(U)) + if (auto *UI = dyn_cast<Instruction>(U)) OperandChangedState(UI); } @@ -1201,7 +1196,7 @@ void SCCPSolver::Solve() { // if (I->getType()->isStructTy() || !getValueState(I).isOverdefined()) for (User *U : I->users()) - if (Instruction *UI = dyn_cast<Instruction>(U)) + if (auto *UI = dyn_cast<Instruction>(U)) OperandChangedState(UI); } @@ -1246,7 +1241,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // Look for instructions which produce undef values. if (I.getType()->isVoidTy()) continue; - if (StructType *STy = dyn_cast<StructType>(I.getType())) { + if (auto *STy = dyn_cast<StructType>(I.getType())) { // Only a few things that can be structs matter for undef. // Tracked calls must never be marked overdefined in ResolvedUndefsIn. @@ -1386,8 +1381,8 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { break; } - // undef >>a X -> all ones - markForcedConstant(&I, Constant::getAllOnesValue(ITy)); + // undef >>a X -> 0 + markForcedConstant(&I, Constant::getNullValue(ITy)); return true; case Instruction::LShr: case Instruction::Shl: @@ -1467,7 +1462,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // we force the branch to go one way or the other to make the successor // values live. It doesn't really matter which way we force it. 
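The visitBinaryOperator change above extends the existing and/or shortcut to mul: once one operand is a known constant zero, the product is zero even if the other operand is already overdefined, and an operand that is still unknown is simply left pending instead of being speculatively annihilated. The related ResolvedUndefsIn change now folds an arithmetic shift of undef to zero rather than all-ones; either constant is a legal choice for an undef operand, the patch merely switches which one is picked. A standalone toy model of the multiply shortcut (plain C++14 for illustration, not LLVM's LatticeVal):

    #include <cassert>
    #include <cstdint>

    // Toy three-state lattice: Unknown (optimistic bottom), Constant, Overdefined (top).
    struct ToyLattice {
      enum Kind { Unknown, Constant, Overdefined } K = Unknown;
      int64_t C = 0;
      static ToyLattice constant(int64_t V) { return {Constant, V}; }
      static ToyLattice overdefined() { return {Overdefined, 0}; }
    };

    static ToyLattice evalMul(ToyLattice A, ToyLattice B) {
      if (A.K == ToyLattice::Constant && B.K == ToyLattice::Constant)
        return ToyLattice::constant(A.C * B.C);  // fold both constants
      if ((A.K == ToyLattice::Constant && A.C == 0) ||
          (B.K == ToyLattice::Constant && B.C == 0))
        return ToyLattice::constant(0);          // x * 0 == 0, even if x is overdefined
      if (A.K == ToyLattice::Unknown || B.K == ToyLattice::Unknown)
        return {};                               // stay optimistic, wait for more facts
      return ToyLattice::overdefined();
    }

    int main() {
      ToyLattice R = evalMul(ToyLattice::overdefined(), ToyLattice::constant(0));
      assert(R.K == ToyLattice::Constant && R.C == 0);
    }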
TerminatorInst *TI = BB.getTerminator(); - if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + if (auto *BI = dyn_cast<BranchInst>(TI)) { if (!BI->isConditional()) continue; if (!getValueState(BI->getCondition()).isUnknown()) continue; @@ -1488,7 +1483,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { return true; } - if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { + if (auto *SI = dyn_cast<SwitchInst>(TI)) { if (!SI->getNumCases() || !getValueState(SI->getCondition()).isUnknown()) continue; @@ -1512,11 +1507,10 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { Constant *Const = nullptr; if (V->getType()->isStructTy()) { std::vector<LatticeVal> IVs = Solver.getStructLatticeValueFor(V); - if (std::any_of(IVs.begin(), IVs.end(), - [](LatticeVal &LV) { return LV.isOverdefined(); })) + if (any_of(IVs, [](const LatticeVal &LV) { return LV.isOverdefined(); })) return false; std::vector<Constant *> ConstVals; - StructType *ST = dyn_cast<StructType>(V->getType()); + auto *ST = dyn_cast<StructType>(V->getType()); for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { LatticeVal V = IVs[i]; ConstVals.push_back(V.isConstant() @@ -1599,7 +1593,7 @@ static bool runSCCP(Function &F, const DataLayout &DL, return MadeChanges; } -PreservedAnalyses SCCPPass::run(Function &F, AnalysisManager<Function> &AM) { +PreservedAnalyses SCCPPass::run(Function &F, FunctionAnalysisManager &AM) { const DataLayout &DL = F.getParent()->getDataLayout(); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); if (!runSCCP(F, DL, &TLI)) @@ -1657,7 +1651,7 @@ static bool AddressIsTaken(const GlobalValue *GV) { for (const Use &U : GV->uses()) { const User *UR = U.getUser(); - if (const StoreInst *SI = dyn_cast<StoreInst>(UR)) { + if (const auto *SI = dyn_cast<StoreInst>(UR)) { if (SI->getOperand(0) == GV || SI->isVolatile()) return true; // Storing addr of GV. } else if (isa<InvokeInst>(UR) || isa<CallInst>(UR)) { @@ -1665,7 +1659,7 @@ static bool AddressIsTaken(const GlobalValue *GV) { ImmutableCallSite CS(cast<Instruction>(UR)); if (!CS.isCallee(&U)) return true; - } else if (const LoadInst *LI = dyn_cast<LoadInst>(UR)) { + } else if (const auto *LI = dyn_cast<LoadInst>(UR)) { if (LI->isVolatile()) return true; } else if (isa<BlockAddress>(UR)) { @@ -1678,6 +1672,19 @@ static bool AddressIsTaken(const GlobalValue *GV) { return false; } +static void findReturnsToZap(Function &F, + SmallPtrSet<Function *, 32> &AddressTakenFunctions, + SmallVector<ReturnInst *, 8> &ReturnsToZap) { + // We can only do this if we know that nothing else can call the function. + if (!F.hasLocalLinkage() || AddressTakenFunctions.count(&F)) + return; + + for (BasicBlock &BB : F) + if (auto *RI = dyn_cast<ReturnInst>(BB.getTerminator())) + if (!isa<UndefValue>(RI->getOperand(0))) + ReturnsToZap.push_back(RI); +} + static bool runIPSCCP(Module &M, const DataLayout &DL, const TargetLibraryInfo *TLI) { SCCPSolver Solver(DL, TLI); @@ -1698,7 +1705,10 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, // If this is an exact definition of this function, then we can propagate // information about its result into callsites of it. - if (F.hasExactDefinition()) + // Don't touch naked functions. They may contain asm returning a + // value we don't see, so we may end up interprocedurally propagating + // the return value incorrectly. 
+ if (F.hasExactDefinition() && !F.hasFnAttribute(Attribute::Naked)) Solver.AddTrackedFunction(&F); // If this function only has direct calls that we can see, we can track its @@ -1800,7 +1810,7 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, UI != UE;) { // Grab the user and then increment the iterator early, as the user // will be deleted. Step past all adjacent uses from the same user. - Instruction *I = dyn_cast<Instruction>(*UI); + auto *I = dyn_cast<Instruction>(*UI); do { ++UI; } while (UI != UE && *UI == I); // Ignore blockaddress users; BasicBlock's dtor will handle them. @@ -1812,10 +1822,10 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, // if this is a branch or switch on undef. Fold it manually as a // branch to the first successor. #ifndef NDEBUG - if (BranchInst *BI = dyn_cast<BranchInst>(I)) { + if (auto *BI = dyn_cast<BranchInst>(I)) { assert(BI->isConditional() && isa<UndefValue>(BI->getCondition()) && "Branch should be foldable!"); - } else if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { + } else if (auto *SI = dyn_cast<SwitchInst>(I)) { assert(isa<UndefValue>(SI->getCondition()) && "Switch should fold"); } else { llvm_unreachable("Didn't fold away reference to block!"); @@ -1853,21 +1863,20 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, // whether other functions are optimizable. SmallVector<ReturnInst*, 8> ReturnsToZap; - // TODO: Process multiple value ret instructions also. const DenseMap<Function*, LatticeVal> &RV = Solver.getTrackedRetVals(); for (const auto &I : RV) { Function *F = I.first; if (I.second.isOverdefined() || F->getReturnType()->isVoidTy()) continue; + findReturnsToZap(*F, AddressTakenFunctions, ReturnsToZap); + } - // We can only do this if we know that nothing else can call the function. - if (!F->hasLocalLinkage() || AddressTakenFunctions.count(F)) - continue; - - for (BasicBlock &BB : *F) - if (ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator())) - if (!isa<UndefValue>(RI->getOperand(0))) - ReturnsToZap.push_back(RI); + for (const auto &F : Solver.getMRVFunctionsTracked()) { + assert(F->getReturnType()->isStructTy() && + "The return type should be a struct"); + StructType *STy = cast<StructType>(F->getReturnType()); + if (Solver.isStructLatticeConstant(F, STy)) + findReturnsToZap(*F, AddressTakenFunctions, ReturnsToZap); } // Zap all returns which we've identified as zap to change. 
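The runIPSCCP change above stops registering naked functions as tracked returns: a naked function's body is typically just inline asm, so the value it really returns is set by the asm and never visible as an IR-level return value, and propagating what the IR does return into callers would be wrong. A hypothetical source-level illustration (Clang's naked attribute on x86-64 assumed; the function names are made up):

    // The real result is produced by the asm; at the IR level the body
    // typically ends in unreachable, so IPSCCP must not treat whatever IR
    // return value it sees as the function's result.
    __attribute__((naked)) int read_magic(void) {
      __asm__ volatile("movl $42, %eax\n\t"
                       "ret");
    }

    int caller(void) {
      return read_magic() + 1;  // must not be constant-folded interprocedurally
    }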
@@ -1896,7 +1905,7 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, return MadeChanges; } -PreservedAnalyses IPSCCPPass::run(Module &M, AnalysisManager<Module> &AM) { +PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { const DataLayout &DL = M.getDataLayout(); auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); if (!runIPSCCP(M, DL, &TLI)) diff --git a/contrib/llvm/lib/Transforms/Scalar/SROA.cpp b/contrib/llvm/lib/Transforms/Scalar/SROA.cpp index 4ce552f..bfcb155 100644 --- a/contrib/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/SROA.cpp @@ -44,12 +44,12 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Operator.h" #include "llvm/Pass.h" +#include "llvm/Support/Chrono.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" -#include "llvm/Support/TimeValue.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" @@ -432,19 +432,18 @@ class AllocaSlices::partition_iterator // cannot change the max split slice end because we just checked that // the prior partition ended prior to that max. P.SplitTails.erase( - std::remove_if( - P.SplitTails.begin(), P.SplitTails.end(), - [&](Slice *S) { return S->endOffset() <= P.EndOffset; }), + remove_if(P.SplitTails, + [&](Slice *S) { return S->endOffset() <= P.EndOffset; }), P.SplitTails.end()); - assert(std::any_of(P.SplitTails.begin(), P.SplitTails.end(), - [&](Slice *S) { - return S->endOffset() == MaxSplitSliceEndOffset; - }) && + assert(any_of(P.SplitTails, + [&](Slice *S) { + return S->endOffset() == MaxSplitSliceEndOffset; + }) && "Could not find the current max split slice offset!"); - assert(std::all_of(P.SplitTails.begin(), P.SplitTails.end(), - [&](Slice *S) { - return S->endOffset() <= MaxSplitSliceEndOffset; - }) && + assert(all_of(P.SplitTails, + [&](Slice *S) { + return S->endOffset() <= MaxSplitSliceEndOffset; + }) && "Max split slice end offset is not actually the max!"); } } @@ -693,7 +692,7 @@ private: break; // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = dyn_cast<StructType>(*GTI)) { + if (StructType *STy = GTI.getStructTypeOrNull()) { unsigned ElementIdx = OpC->getZExtValue(); const StructLayout *SL = DL.getStructLayout(STy); GEPOffset += @@ -996,15 +995,13 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI) return; } - Slices.erase(std::remove_if(Slices.begin(), Slices.end(), - [](const Slice &S) { - return S.isDead(); - }), + Slices.erase(remove_if(Slices, [](const Slice &S) { return S.isDead(); }), Slices.end()); #ifndef NDEBUG if (SROARandomShuffleSlices) { - std::mt19937 MT(static_cast<unsigned>(sys::TimeValue::now().msec())); + std::mt19937 MT(static_cast<unsigned>( + std::chrono::system_clock::now().time_since_epoch().count())); std::shuffle(Slices.begin(), Slices.end(), MT); } #endif @@ -1815,10 +1812,10 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { // do that until all the backends are known to produce good code for all // integer vector types. 
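Earlier in this SROA hunk, the debug-only slice shuffle switches from the removed llvm::sys::TimeValue to std::chrono for its seed. The same seeding pattern in isolation, using only the standard library:

    #include <algorithm>
    #include <chrono>
    #include <random>
    #include <vector>

    int main() {
      std::vector<int> Slices = {1, 2, 3, 4, 5};
      // Seed from the wall clock, as the patch now does with std::chrono.
      std::mt19937 MT(static_cast<unsigned>(
          std::chrono::system_clock::now().time_since_epoch().count()));
      std::shuffle(Slices.begin(), Slices.end(), MT);
    }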
if (!HaveCommonEltTy) { - CandidateTys.erase(std::remove_if(CandidateTys.begin(), CandidateTys.end(), - [](VectorType *VTy) { - return !VTy->getElementType()->isIntegerTy(); - }), + CandidateTys.erase(remove_if(CandidateTys, + [](VectorType *VTy) { + return !VTy->getElementType()->isIntegerTy(); + }), CandidateTys.end()); // If there were no integer vector types, give up. @@ -2486,8 +2483,8 @@ private: } V = convertValue(DL, IRB, V, NewAllocaTy); StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); + Store->copyMetadata(SI, LLVMContext::MD_mem_parallel_loop_access); Pass.DeadInsts.insert(&SI); - (void)Store; DEBUG(dbgs() << " to: " << *Store << "\n"); return true; } @@ -2549,6 +2546,7 @@ private: NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(V->getType()), SI.isVolatile()); } + NewSI->copyMetadata(SI, LLVMContext::MD_mem_parallel_loop_access); if (SI.isVolatile()) NewSI->setAtomic(SI.getOrdering(), SI.getSynchScope()); Pass.DeadInsts.insert(&SI); @@ -2878,6 +2876,17 @@ private: // Record this instruction for deletion. Pass.DeadInsts.insert(&II); + // Lifetime intrinsics are only promotable if they cover the whole alloca. + // Therefore, we drop lifetime intrinsics which don't cover the whole + // alloca. + // (In theory, intrinsics which partially cover an alloca could be + // promoted, but PromoteMemToReg doesn't handle that case.) + // FIXME: Check whether the alloca is promotable before dropping the + // lifetime intrinsics? + if (NewBeginOffset != NewAllocaBeginOffset || + NewEndOffset != NewAllocaEndOffset) + return true; + ConstantInt *Size = ConstantInt::get(cast<IntegerType>(II.getArgOperand(0)->getType()), NewEndOffset - NewBeginOffset); @@ -2890,6 +2899,7 @@ private: (void)New; DEBUG(dbgs() << " to: " << *New << "\n"); + return true; } @@ -3209,20 +3219,11 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset, return nullptr; if (SequentialType *SeqTy = dyn_cast<SequentialType>(Ty)) { - // We can't partition pointers... - if (SeqTy->isPointerTy()) - return nullptr; - Type *ElementTy = SeqTy->getElementType(); uint64_t ElementSize = DL.getTypeAllocSize(ElementTy); uint64_t NumSkippedElements = Offset / ElementSize; - if (ArrayType *ArrTy = dyn_cast<ArrayType>(SeqTy)) { - if (NumSkippedElements >= ArrTy->getNumElements()) - return nullptr; - } else if (VectorType *VecTy = dyn_cast<VectorType>(SeqTy)) { - if (NumSkippedElements >= VecTy->getNumElements()) - return nullptr; - } + if (NumSkippedElements >= SeqTy->getNumElements()) + return nullptr; Offset -= NumSkippedElements * ElementSize; // First check if we need to recurse. @@ -3456,63 +3457,60 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { // match relative to their starting offset. We have to verify this prior to // any rewriting. Stores.erase( - std::remove_if(Stores.begin(), Stores.end(), - [&UnsplittableLoads, &SplitOffsetsMap](StoreInst *SI) { - // Lookup the load we are storing in our map of split - // offsets. - auto *LI = cast<LoadInst>(SI->getValueOperand()); - // If it was completely unsplittable, then we're done, - // and this store can't be pre-split. - if (UnsplittableLoads.count(LI)) - return true; - - auto LoadOffsetsI = SplitOffsetsMap.find(LI); - if (LoadOffsetsI == SplitOffsetsMap.end()) - return false; // Unrelated loads are definitely safe. - auto &LoadOffsets = LoadOffsetsI->second; - - // Now lookup the store's offsets. 
- auto &StoreOffsets = SplitOffsetsMap[SI]; - - // If the relative offsets of each split in the load and - // store match exactly, then we can split them and we - // don't need to remove them here. - if (LoadOffsets.Splits == StoreOffsets.Splits) - return false; - - DEBUG(dbgs() - << " Mismatched splits for load and store:\n" - << " " << *LI << "\n" - << " " << *SI << "\n"); - - // We've found a store and load that we need to split - // with mismatched relative splits. Just give up on them - // and remove both instructions from our list of - // candidates. - UnsplittableLoads.insert(LI); - return true; - }), + remove_if(Stores, + [&UnsplittableLoads, &SplitOffsetsMap](StoreInst *SI) { + // Lookup the load we are storing in our map of split + // offsets. + auto *LI = cast<LoadInst>(SI->getValueOperand()); + // If it was completely unsplittable, then we're done, + // and this store can't be pre-split. + if (UnsplittableLoads.count(LI)) + return true; + + auto LoadOffsetsI = SplitOffsetsMap.find(LI); + if (LoadOffsetsI == SplitOffsetsMap.end()) + return false; // Unrelated loads are definitely safe. + auto &LoadOffsets = LoadOffsetsI->second; + + // Now lookup the store's offsets. + auto &StoreOffsets = SplitOffsetsMap[SI]; + + // If the relative offsets of each split in the load and + // store match exactly, then we can split them and we + // don't need to remove them here. + if (LoadOffsets.Splits == StoreOffsets.Splits) + return false; + + DEBUG(dbgs() << " Mismatched splits for load and store:\n" + << " " << *LI << "\n" + << " " << *SI << "\n"); + + // We've found a store and load that we need to split + // with mismatched relative splits. Just give up on them + // and remove both instructions from our list of + // candidates. + UnsplittableLoads.insert(LI); + return true; + }), Stores.end()); // Now we have to go *back* through all the stores, because a later store may // have caused an earlier store's load to become unsplittable and if it is // unsplittable for the later store, then we can't rely on it being split in // the earlier store either. - Stores.erase(std::remove_if(Stores.begin(), Stores.end(), - [&UnsplittableLoads](StoreInst *SI) { - auto *LI = - cast<LoadInst>(SI->getValueOperand()); - return UnsplittableLoads.count(LI); - }), + Stores.erase(remove_if(Stores, + [&UnsplittableLoads](StoreInst *SI) { + auto *LI = cast<LoadInst>(SI->getValueOperand()); + return UnsplittableLoads.count(LI); + }), Stores.end()); // Once we've established all the loads that can't be split for some reason, // filter any that made it into our list out. - Loads.erase(std::remove_if(Loads.begin(), Loads.end(), - [&UnsplittableLoads](LoadInst *LI) { - return UnsplittableLoads.count(LI); - }), + Loads.erase(remove_if(Loads, + [&UnsplittableLoads](LoadInst *LI) { + return UnsplittableLoads.count(LI); + }), Loads.end()); - // If no loads or stores are left, there is no pre-splitting to be done for // this alloca. if (Loads.empty() && Stores.empty()) @@ -3570,6 +3568,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { PartPtrTy, BasePtr->getName() + "."), getAdjustedAlignment(LI, PartOffset, DL), /*IsVolatile*/ false, LI->getName()); + PLoad->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access); // Append this load onto the list of split loads so we can find it later // to rewrite the stores. 
@@ -3622,7 +3621,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { APInt(DL.getPointerSizeInBits(), PartOffset), PartPtrTy, StoreBasePtr->getName() + "."), getAdjustedAlignment(SI, PartOffset, DL), /*IsVolatile*/ false); - (void)PStore; + PStore->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access); DEBUG(dbgs() << " +" << PartOffset << ":" << *PStore << "\n"); } @@ -3770,9 +3769,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { } // Remove the killed slices that have ben pre-split. - AS.erase(std::remove_if(AS.begin(), AS.end(), [](const Slice &S) { - return S.isDead(); - }), AS.end()); + AS.erase(remove_if(AS, [](const Slice &S) { return S.isDead(); }), AS.end()); // Insert our new slices. This will sort and merge them into the sorted // sequence. @@ -3787,8 +3784,8 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { // Finally, don't try to promote any allocas that new require re-splitting. // They have already been added to the worklist above. PromotableAllocas.erase( - std::remove_if( - PromotableAllocas.begin(), PromotableAllocas.end(), + remove_if( + PromotableAllocas, [&](AllocaInst *AI) { return ResplitPromotableAllocas.count(AI); }), PromotableAllocas.end()); @@ -3985,16 +3982,16 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { if (!IsSorted) std::sort(AS.begin(), AS.end()); - /// \brief Describes the allocas introduced by rewritePartition - /// in order to migrate the debug info. - struct Piece { + /// Describes the allocas introduced by rewritePartition in order to migrate + /// the debug info. + struct Fragment { AllocaInst *Alloca; uint64_t Offset; uint64_t Size; - Piece(AllocaInst *AI, uint64_t O, uint64_t S) + Fragment(AllocaInst *AI, uint64_t O, uint64_t S) : Alloca(AI), Offset(O), Size(S) {} }; - SmallVector<Piece, 4> Pieces; + SmallVector<Fragment, 4> Fragments; // Rewrite each partition. for (auto &P : AS.partitions()) { @@ -4005,7 +4002,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { uint64_t AllocaSize = DL.getTypeSizeInBits(NewAI->getAllocatedType()); // Don't include any padding. uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte); - Pieces.push_back(Piece(NewAI, P.beginOffset() * SizeOfByte, Size)); + Fragments.push_back(Fragment(NewAI, P.beginOffset() * SizeOfByte, Size)); } } ++NumPartitions; @@ -4022,32 +4019,34 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { auto *Expr = DbgDecl->getExpression(); DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType()); - for (auto Piece : Pieces) { - // Create a piece expression describing the new partition or reuse AI's + for (auto Fragment : Fragments) { + // Create a fragment expression describing the new partition or reuse AI's // expression if there is only one partition. - auto *PieceExpr = Expr; - if (Piece.Size < AllocaSize || Expr->isBitPiece()) { + auto *FragmentExpr = Expr; + if (Fragment.Size < AllocaSize || Expr->isFragment()) { // If this alloca is already a scalar replacement of a larger aggregate, - // Piece.Offset describes the offset inside the scalar. - uint64_t Offset = Expr->isBitPiece() ? Expr->getBitPieceOffset() : 0; - uint64_t Start = Offset + Piece.Offset; - uint64_t Size = Piece.Size; - if (Expr->isBitPiece()) { - uint64_t AbsEnd = Expr->getBitPieceOffset() + Expr->getBitPieceSize(); + // Fragment.Offset describes the offset inside the scalar. 
+ auto ExprFragment = Expr->getFragmentInfo(); + uint64_t Offset = ExprFragment ? ExprFragment->OffsetInBits : 0; + uint64_t Start = Offset + Fragment.Offset; + uint64_t Size = Fragment.Size; + if (ExprFragment) { + uint64_t AbsEnd = + ExprFragment->OffsetInBits + ExprFragment->SizeInBits; if (Start >= AbsEnd) // No need to describe a SROAed padding. continue; Size = std::min(Size, AbsEnd - Start); } - PieceExpr = DIB.createBitPieceExpression(Start, Size); + FragmentExpr = DIB.createFragmentExpression(Start, Size); } // Remove any existing dbg.declare intrinsic describing the same alloca. - if (DbgDeclareInst *OldDDI = FindAllocaDbgDeclare(Piece.Alloca)) + if (DbgDeclareInst *OldDDI = FindAllocaDbgDeclare(Fragment.Alloca)) OldDDI->eraseFromParent(); - DIB.insertDeclare(Piece.Alloca, Var, PieceExpr, DbgDecl->getDebugLoc(), - &AI); + DIB.insertDeclare(Fragment.Alloca, Var, FragmentExpr, + DbgDecl->getDebugLoc(), &AI); } } return Changed; @@ -4220,9 +4219,7 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, auto IsInSet = [&](AllocaInst *AI) { return DeletedAllocas.count(AI); }; Worklist.remove_if(IsInSet); PostPromotionWorklist.remove_if(IsInSet); - PromotableAllocas.erase(std::remove_if(PromotableAllocas.begin(), - PromotableAllocas.end(), - IsInSet), + PromotableAllocas.erase(remove_if(PromotableAllocas, IsInSet), PromotableAllocas.end()); DeletedAllocas.clear(); } @@ -4244,7 +4241,7 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, return PA; } -PreservedAnalyses SROA::run(Function &F, AnalysisManager<Function> &AM) { +PreservedAnalyses SROA::run(Function &F, FunctionAnalysisManager &AM) { return runImpl(F, AM.getResult<DominatorTreeAnalysis>(F), AM.getResult<AssumptionAnalysis>(F)); } @@ -4277,7 +4274,7 @@ public: AU.setPreservesCFG(); } - const char *getPassName() const override { return "SROA"; } + StringRef getPassName() const override { return "SROA"; } static char ID; }; diff --git a/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp b/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp index f235b12..afe7483 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -43,14 +43,17 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeDSELegacyPassPass(Registry); initializeGuardWideningLegacyPassPass(Registry); initializeGVNLegacyPassPass(Registry); + initializeNewGVNPass(Registry); initializeEarlyCSELegacyPassPass(Registry); + initializeEarlyCSEMemSSALegacyPassPass(Registry); initializeGVNHoistLegacyPassPass(Registry); initializeFlattenCFGPassPass(Registry); initializeInductiveRangeCheckEliminationPass(Registry); initializeIndVarSimplifyLegacyPassPass(Registry); initializeJumpThreadingPass(Registry); initializeLegacyLICMPassPass(Registry); - initializeLoopDataPrefetchPass(Registry); + initializeLegacyLoopSinkPassPass(Registry); + initializeLoopDataPrefetchLegacyPassPass(Registry); initializeLoopDeletionLegacyPassPass(Registry); initializeLoopAccessLegacyAnalysisPass(Registry); initializeLoopInstSimplifyLegacyPassPass(Registry); @@ -64,10 +67,10 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLoopIdiomRecognizeLegacyPassPass(Registry); initializeLowerAtomicLegacyPassPass(Registry); initializeLowerExpectIntrinsicPass(Registry); - initializeLowerGuardIntrinsicPass(Registry); + initializeLowerGuardIntrinsicLegacyPassPass(Registry); initializeMemCpyOptLegacyPassPass(Registry); initializeMergedLoadStoreMotionLegacyPassPass(Registry); - 
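The splitAlloca debug-info hunk above moves from bit-piece to fragment expressions: DIExpression::getFragmentInfo() replaces the getBitPieceOffset/getBitPieceSize pair and DIBuilder::createFragmentExpression replaces createBitPieceExpression, while the clamping arithmetic stays the same. That arithmetic, restated as a standalone helper (C++17 std::optional, names invented for illustration):

    #include <algorithm>
    #include <cstdint>
    #include <optional>

    // Given a partition of the alloca and the fragment already described by the
    // variable's expression (if any), compute the fragment for the new
    // dbg.declare, or nothing if the partition covers only SROAed padding.
    // All offsets and sizes are in bits.
    struct BitRange { uint64_t OffsetInBits, SizeInBits; };

    std::optional<BitRange>
    fragmentForPartition(BitRange Partition, std::optional<BitRange> Existing) {
      uint64_t Base = Existing ? Existing->OffsetInBits : 0;
      uint64_t Start = Base + Partition.OffsetInBits;
      uint64_t Size = Partition.SizeInBits;
      if (Existing) {
        uint64_t AbsEnd = Existing->OffsetInBits + Existing->SizeInBits;
        if (Start >= AbsEnd)
          return std::nullopt;                 // nothing but padding to describe
        Size = std::min(Size, AbsEnd - Start);
      }
      return BitRange{Start, Size};
    }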
initializeNaryReassociatePass(Registry); + initializeNaryReassociateLegacyPassPass(Registry); initializePartiallyInlineLibCallsLegacyPassPass(Registry); initializeReassociateLegacyPassPass(Registry); initializeRegToMemPass(Registry); @@ -80,7 +83,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeSinkingLegacyPassPass(Registry); initializeTailCallElimPass(Registry); initializeSeparateConstOffsetFromGEPPass(Registry); - initializeSpeculativeExecutionPass(Registry); + initializeSpeculativeExecutionLegacyPassPass(Registry); initializeStraightLineStrengthReducePass(Registry); initializeLoadCombinePass(Registry); initializePlaceBackedgeSafepointsImplPass(Registry); @@ -124,6 +127,10 @@ void LLVMAddGVNPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createGVNPass()); } +void LLVMAddNewGVNPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createNewGVNPass()); +} + void LLVMAddMergedLoadStoreMotionPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createMergedLoadStoreMotionPass()); } @@ -140,6 +147,10 @@ void LLVMAddJumpThreadingPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createJumpThreadingPass()); } +void LLVMAddLoopSinkPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopSinkPass()); +} + void LLVMAddLICMPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLICMPass()); } @@ -234,7 +245,11 @@ void LLVMAddCorrelatedValuePropagationPass(LLVMPassManagerRef PM) { } void LLVMAddEarlyCSEPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createEarlyCSEPass()); + unwrap(PM)->add(createEarlyCSEPass(false/*=UseMemorySSA*/)); +} + +void LLVMAddEarlyCSEMemSSAPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createEarlyCSEPass(true/*=UseMemorySSA*/)); } void LLVMAddGVNHoistLegacyPass(LLVMPassManagerRef PM) { diff --git a/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp index aed4a4a..39969e2 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -16,6 +16,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/Pass.h" @@ -148,6 +149,7 @@ public: bool visitPHINode(PHINode &); bool visitLoadInst(LoadInst &); bool visitStoreInst(StoreInst &); + bool visitCallInst(CallInst &I); static void registerOptions() { // This is disabled by default because having separate loads and stores @@ -169,6 +171,8 @@ private: template<typename T> bool splitBinary(Instruction &, const T &); + bool splitCall(CallInst &CI); + ScatterMap Scattered; GatherList Gathered; unsigned ParallelLoopAccessMDKind; @@ -394,6 +398,77 @@ bool Scalarizer::splitBinary(Instruction &I, const Splitter &Split) { return true; } +static bool isTriviallyScalariable(Intrinsic::ID ID) { + return isTriviallyVectorizable(ID); +} + +// All of the current scalarizable intrinsics only have one mangled type. +static Function *getScalarIntrinsicDeclaration(Module *M, + Intrinsic::ID ID, + VectorType *Ty) { + return Intrinsic::getDeclaration(M, ID, { Ty->getScalarType() }); +} + +/// If a call to a vector typed intrinsic function, split into a scalar call per +/// element if possible for the intrinsic. 
+bool Scalarizer::splitCall(CallInst &CI) { + VectorType *VT = dyn_cast<VectorType>(CI.getType()); + if (!VT) + return false; + + Function *F = CI.getCalledFunction(); + if (!F) + return false; + + Intrinsic::ID ID = F->getIntrinsicID(); + if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID)) + return false; + + unsigned NumElems = VT->getNumElements(); + unsigned NumArgs = CI.getNumArgOperands(); + + ValueVector ScalarOperands(NumArgs); + SmallVector<Scatterer, 8> Scattered(NumArgs); + + Scattered.resize(NumArgs); + + // Assumes that any vector type has the same number of elements as the return + // vector type, which is true for all current intrinsics. + for (unsigned I = 0; I != NumArgs; ++I) { + Value *OpI = CI.getOperand(I); + if (OpI->getType()->isVectorTy()) { + Scattered[I] = scatter(&CI, OpI); + assert(Scattered[I].size() == NumElems && "mismatched call operands"); + } else { + ScalarOperands[I] = OpI; + } + } + + ValueVector Res(NumElems); + ValueVector ScalarCallOps(NumArgs); + + Function *NewIntrin = getScalarIntrinsicDeclaration(F->getParent(), ID, VT); + IRBuilder<> Builder(&CI); + + // Perform actual scalarization, taking care to preserve any scalar operands. + for (unsigned Elem = 0; Elem < NumElems; ++Elem) { + ScalarCallOps.clear(); + + for (unsigned J = 0; J != NumArgs; ++J) { + if (hasVectorInstrinsicScalarOpd(ID, J)) + ScalarCallOps.push_back(ScalarOperands[J]); + else + ScalarCallOps.push_back(Scattered[J][Elem]); + } + + Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps, + CI.getName() + ".i" + Twine(Elem)); + } + + gather(&CI, Res); + return true; +} + bool Scalarizer::visitSelectInst(SelectInst &SI) { VectorType *VT = dyn_cast<VectorType>(SI.getType()); if (!VT) @@ -642,6 +717,10 @@ bool Scalarizer::visitStoreInst(StoreInst &SI) { return true; } +bool Scalarizer::visitCallInst(CallInst &CI) { + return splitCall(CI); +} + // Delete the instructions that we scalarized. If a full vector result // is still needed, recreate it using InsertElements. bool Scalarizer::finish() { diff --git a/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index d6ae186..4d59453 100644 --- a/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -722,7 +722,7 @@ bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToPointerSize( for (User::op_iterator I = GEP->op_begin() + 1, E = GEP->op_end(); I != E; ++I, ++GTI) { // Skip struct member indices which must be i32. - if (isa<SequentialType>(*GTI)) { + if (GTI.isSequential()) { if ((*I)->getType() != IntPtrTy) { *I = CastInst::CreateIntegerCast(*I, IntPtrTy, true, "idxprom", GEP); Changed = true; @@ -739,7 +739,7 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP, int64_t AccumulativeByteOffset = 0; gep_type_iterator GTI = gep_type_begin(*GEP); for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { - if (isa<SequentialType>(*GTI)) { + if (GTI.isSequential()) { // Tries to extract a constant offset from this GEP index. 
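The new Scalarizer::splitCall above rewrites a call of a trivially vectorizable intrinsic into one scalar call per lane, scattering vector operands and gathering the lane results back together. A rough standalone sketch of the same shape, using plain extract/insertelement instead of the pass's internal Scatterer/gather machinery and ignoring intrinsics with scalar operands (the helper name and the absence of legality checks are simplifications, not the pass's code):

    #include "llvm/IR/Constants.h"
    #include "llvm/IR/IRBuilder.h"
    #include "llvm/IR/Instructions.h"
    #include "llvm/IR/Intrinsics.h"
    #include "llvm/IR/Module.h"
    using namespace llvm;

    // Sketch: turn e.g. a call of llvm.fabs.v4f32 into four llvm.fabs.f32 calls.
    static Value *scalarizeUnaryIntrinsicCall(CallInst &CI) {
      auto *VT = cast<VectorType>(CI.getType());
      Function *F = CI.getCalledFunction();
      Function *ScalarFn = Intrinsic::getDeclaration(
          F->getParent(), F->getIntrinsicID(), {VT->getScalarType()});

      IRBuilder<> B(&CI);
      Value *Result = UndefValue::get(VT);
      for (unsigned Lane = 0, E = VT->getNumElements(); Lane != E; ++Lane) {
        Value *Elt = B.CreateExtractElement(CI.getArgOperand(0), Lane);
        Value *LaneCall = B.CreateCall(ScalarFn, {Elt});
        Result = B.CreateInsertElement(Result, LaneCall, Lane);
      }
      return Result; // the caller would RAUW CI with this and erase CI
    }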
int64_t ConstantOffset = ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP, DT); @@ -752,7 +752,7 @@ SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP, ConstantOffset * DL->getTypeAllocSize(GTI.getIndexedType()); } } else if (LowerGEP) { - StructType *StTy = cast<StructType>(*GTI); + StructType *StTy = GTI.getStructType(); uint64_t Field = cast<ConstantInt>(GEP->getOperand(I))->getZExtValue(); // Skip field 0 as the offset is always 0. if (Field != 0) { @@ -787,7 +787,7 @@ void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( // Create an ugly GEP for each sequential index. We don't create GEPs for // structure indices, as they are accumulated in the constant offset index. for (unsigned I = 1, E = Variadic->getNumOperands(); I != E; ++I, ++GTI) { - if (isa<SequentialType>(*GTI)) { + if (GTI.isSequential()) { Value *Idx = Variadic->getOperand(I); // Skip zero indices. if (ConstantInt *CI = dyn_cast<ConstantInt>(Idx)) @@ -848,7 +848,7 @@ SeparateConstOffsetFromGEP::lowerToArithmetics(GetElementPtrInst *Variadic, // don't create arithmetics for structure indices, as they are accumulated // in the constant offset index. for (unsigned I = 1, E = Variadic->getNumOperands(); I != E; ++I, ++GTI) { - if (isa<SequentialType>(*GTI)) { + if (GTI.isSequential()) { Value *Idx = Variadic->getOperand(I); // Skip zero indices. if (ConstantInt *CI = dyn_cast<ConstantInt>(Idx)) @@ -928,7 +928,7 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { // handle the constant offset and won't need a new structure index. gep_type_iterator GTI = gep_type_begin(*GEP); for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { - if (isa<SequentialType>(*GTI)) { + if (GTI.isSequential()) { // Splits this GEP index into a variadic part and a constant offset, and // uses the variadic part as the new index. 
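The SeparateConstOffsetFromGEP loops above switch from type-testing *GTI to the gep_type_iterator queries isSequential(), getStructType(), and getIndexedType(). A condensed sketch of that idiom, computing the byte offset of a GEP whose indices are assumed (not checked here) to all be ConstantInts:

    #include "llvm/IR/Constants.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/GetElementPtrTypeIterator.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // Sketch only: mirrors the isSequential()/getStructType() split used above.
    static int64_t constantGEPByteOffset(const DataLayout &DL,
                                         GetElementPtrInst *GEP) {
      int64_t Offset = 0;
      gep_type_iterator GTI = gep_type_begin(*GEP);
      for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
        auto *Idx = cast<ConstantInt>(GEP->getOperand(I));
        if (GTI.isSequential()) {
          // Array/vector/pointer step: index times element size.
          Offset += Idx->getSExtValue() *
                    static_cast<int64_t>(DL.getTypeAllocSize(GTI.getIndexedType()));
        } else {
          // Struct member: field offset from the struct layout.
          StructType *STy = GTI.getStructType();
          Offset += DL.getStructLayout(STy)->getElementOffset(Idx->getZExtValue());
        }
      }
      return Offset;
    }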
Value *OldIdx = GEP->getOperand(I); @@ -1150,8 +1150,7 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { bool SeparateConstOffsetFromGEP::reuniteExts(Function &F) { bool Changed = false; DominatingExprs.clear(); - for (auto Node = GraphTraits<DominatorTree *>::nodes_begin(DT); - Node != GraphTraits<DominatorTree *>::nodes_end(DT); ++Node) { + for (const auto Node : depth_first(DT)) { BasicBlock *BB = Node->getBlock(); for (auto I = BB->begin(); I != BB->end(); ) { Instruction *Cur = &*I++; diff --git a/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 2d0a21d..f2723bd 100644 --- a/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -187,7 +187,7 @@ SimplifyCFGPass::SimplifyCFGPass(int BonusInstThreshold) : BonusInstThreshold(BonusInstThreshold) {} PreservedAnalyses SimplifyCFGPass::run(Function &F, - AnalysisManager<Function> &AM) { + FunctionAnalysisManager &AM) { auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); diff --git a/contrib/llvm/lib/Transforms/Scalar/Sink.cpp b/contrib/llvm/lib/Transforms/Scalar/Sink.cpp index d9a296c..c3f14a0 100644 --- a/contrib/llvm/lib/Transforms/Scalar/Sink.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/Sink.cpp @@ -254,7 +254,7 @@ static bool iterativelySinkInstructions(Function &F, DominatorTree &DT, return EverMadeChange; } -PreservedAnalyses SinkingPass::run(Function &F, AnalysisManager<Function> &AM) { +PreservedAnalyses SinkingPass::run(Function &F, FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); diff --git a/contrib/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/contrib/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp index 9bf2d62..a7c308b 100644 --- a/contrib/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -61,9 +61,9 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/SpeculativeExecution.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -101,58 +101,62 @@ static cl::opt<bool> SpecExecOnlyIfDivergentTarget( namespace { -class SpeculativeExecution : public FunctionPass { - public: - static char ID; - explicit SpeculativeExecution(bool OnlyIfDivergentTarget = false) - : FunctionPass(ID), - OnlyIfDivergentTarget(OnlyIfDivergentTarget || - SpecExecOnlyIfDivergentTarget) {} - - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; - - const char *getPassName() const override { - if (OnlyIfDivergentTarget) - return "Speculatively execute instructions if target has divergent " - "branches"; - return "Speculatively execute instructions"; - } - - private: - bool runOnBasicBlock(BasicBlock &B); - bool considerHoistingFromTo(BasicBlock &FromBlock, BasicBlock &ToBlock); - - // If true, this pass is a nop unless the target architecture has branch - // divergence. 
+class SpeculativeExecutionLegacyPass : public FunctionPass { +public: + static char ID; + explicit SpeculativeExecutionLegacyPass(bool OnlyIfDivergentTarget = false) + : FunctionPass(ID), OnlyIfDivergentTarget(OnlyIfDivergentTarget || + SpecExecOnlyIfDivergentTarget), + Impl(OnlyIfDivergentTarget) {} + + void getAnalysisUsage(AnalysisUsage &AU) const override; + bool runOnFunction(Function &F) override; + + StringRef getPassName() const override { + if (OnlyIfDivergentTarget) + return "Speculatively execute instructions if target has divergent " + "branches"; + return "Speculatively execute instructions"; + } + +private: + // Variable preserved purely for correct name printing. const bool OnlyIfDivergentTarget; - const TargetTransformInfo *TTI = nullptr; + + SpeculativeExecutionPass Impl; }; } // namespace -char SpeculativeExecution::ID = 0; -INITIALIZE_PASS_BEGIN(SpeculativeExecution, "speculative-execution", +char SpeculativeExecutionLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(SpeculativeExecutionLegacyPass, "speculative-execution", "Speculatively execute instructions", false, false) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(SpeculativeExecution, "speculative-execution", +INITIALIZE_PASS_END(SpeculativeExecutionLegacyPass, "speculative-execution", "Speculatively execute instructions", false, false) -void SpeculativeExecution::getAnalysisUsage(AnalysisUsage &AU) const { +void SpeculativeExecutionLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } -bool SpeculativeExecution::runOnFunction(Function &F) { +bool SpeculativeExecutionLegacyPass::runOnFunction(Function &F) { if (skipFunction(F)) return false; - TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + return Impl.runImpl(F, TTI); +} + +namespace llvm { + +bool SpeculativeExecutionPass::runImpl(Function &F, TargetTransformInfo *TTI) { if (OnlyIfDivergentTarget && !TTI->hasBranchDivergence()) { DEBUG(dbgs() << "Not running SpeculativeExecution because " "TTI->hasBranchDivergence() is false.\n"); return false; } + this->TTI = TTI; bool Changed = false; for (auto& B : F) { Changed |= runOnBasicBlock(B); @@ -160,7 +164,7 @@ bool SpeculativeExecution::runOnFunction(Function &F) { return Changed; } -bool SpeculativeExecution::runOnBasicBlock(BasicBlock &B) { +bool SpeculativeExecutionPass::runOnBasicBlock(BasicBlock &B) { BranchInst *BI = dyn_cast<BranchInst>(B.getTerminator()); if (BI == nullptr) return false; @@ -220,6 +224,24 @@ static unsigned ComputeSpeculationCost(const Instruction *I, case Instruction::Xor: case Instruction::ZExt: case Instruction::SExt: + case Instruction::Call: + case Instruction::BitCast: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::AddrSpaceCast: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPExt: + case Instruction::FPTrunc: + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + case Instruction::ICmp: + case Instruction::FCmp: return TTI.getUserCost(I); default: @@ -227,8 +249,8 @@ static unsigned ComputeSpeculationCost(const Instruction *I, } } -bool SpeculativeExecution::considerHoistingFromTo(BasicBlock &FromBlock, - BasicBlock &ToBlock) { +bool SpeculativeExecutionPass::considerHoistingFromTo( + 
BasicBlock &FromBlock, BasicBlock &ToBlock) { SmallSet<const Instruction *, 8> NotHoisted; const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](User *U) { for (Value* V : U->operand_values()) { @@ -270,14 +292,28 @@ bool SpeculativeExecution::considerHoistingFromTo(BasicBlock &FromBlock, return true; } -namespace llvm { - FunctionPass *createSpeculativeExecutionPass() { - return new SpeculativeExecution(); + return new SpeculativeExecutionLegacyPass(); } FunctionPass *createSpeculativeExecutionIfHasBranchDivergencePass() { - return new SpeculativeExecution(/* OnlyIfDivergentTarget = */ true); + return new SpeculativeExecutionLegacyPass(/* OnlyIfDivergentTarget = */ true); } +SpeculativeExecutionPass::SpeculativeExecutionPass(bool OnlyIfDivergentTarget) + : OnlyIfDivergentTarget(OnlyIfDivergentTarget || + SpecExecOnlyIfDivergentTarget) {} + +PreservedAnalyses SpeculativeExecutionPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto *TTI = &AM.getResult<TargetIRAnalysis>(F); + + bool Changed = runImpl(F, TTI); + + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; +} } // namespace llvm diff --git a/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index 292d040..2be3f5c 100644 --- a/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -55,8 +55,6 @@ // // - When (i' - i) is constant but i and i' are not, we could still perform // SLSR. -#include <vector> - #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -68,6 +66,8 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include <list> +#include <vector> using namespace llvm; using namespace PatternMatch; @@ -80,7 +80,7 @@ class StraightLineStrengthReduce : public FunctionPass { public: // SLSR candidate. Such a candidate must be in one of the forms described in // the header comments. - struct Candidate : public ilist_node<Candidate> { + struct Candidate { enum Kind { Invalid, // reserved for the default constructor Add, // B + i * S @@ -200,7 +200,7 @@ private: DominatorTree *DT; ScalarEvolution *SE; TargetTransformInfo *TTI; - ilist<Candidate> Candidates; + std::list<Candidate> Candidates; // Temporarily holds all instructions that are unlinked (but not deleted) by // rewriteCandidateWithBasis. These instructions will be actually removed // after all rewriting finishes. @@ -490,8 +490,8 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( IndexExprs.push_back(SE->getSCEV(*I)); gep_type_iterator GTI = gep_type_begin(GEP); - for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I) { - if (!isa<SequentialType>(*GTI++)) + for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { + if (GTI.isStruct()) continue; const SCEV *OrigIndexExpr = IndexExprs[I - 1]; @@ -499,11 +499,9 @@ void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( // The base of this candidate is GEP's base plus the offsets of all // indices except this current one. 
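The expanded ComputeSpeculationCost switch above lets the pass price casts, floating-point arithmetic, compares, and calls instead of only the handful of integer operations it handled before. Illustrative input only, in plain C++: with the new entries, the floating-point work in the conditional block below can now be costed and, if TTI says it is cheap and its operands are available, hoisted by considerHoistingFromTo, whereas previously it fell through to the default case and was never speculated by this pass.

    float scaleIfSet(float X, bool Flag) {
      float R = 0.0f;
      if (Flag)
        R = X * 2.0f + 1.0f;   // fmul + fadd: previously never speculated by this pass
      return R;
    }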
- const SCEV *BaseExpr = SE->getGEPExpr(GEP->getSourceElementType(), - SE->getSCEV(GEP->getPointerOperand()), - IndexExprs, GEP->isInBounds()); + const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs); Value *ArrayIdx = GEP->getOperand(I); - uint64_t ElementSize = DL->getTypeAllocSize(*GTI); + uint64_t ElementSize = DL->getTypeAllocSize(GTI.getIndexedType()); if (ArrayIdx->getType()->getIntegerBitWidth() <= DL->getPointerSizeInBits(GEP->getAddressSpace())) { // Skip factoring if ArrayIdx is wider than the pointer size, because @@ -674,11 +672,9 @@ bool StraightLineStrengthReduce::runOnFunction(Function &F) { SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); // Traverse the dominator tree in the depth-first order. This order makes sure // all bases of a candidate are in Candidates when we process it. - for (auto node = GraphTraits<DominatorTree *>::nodes_begin(DT); - node != GraphTraits<DominatorTree *>::nodes_end(DT); ++node) { - for (auto &I : *node->getBlock()) + for (const auto Node : depth_first(DT)) + for (auto &I : *(Node->getBlock())) allocateCandidatesAndFindBasis(&I); - } // Rewrite candidates in the reverse depth-first order. This order makes sure // a candidate being rewritten is not a basis for any other candidate. diff --git a/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index e9ac39b..49ce026 100644 --- a/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -43,77 +43,58 @@ typedef SmallPtrSet<BasicBlock *, 8> BBSet; typedef MapVector<PHINode *, BBValueVector> PhiMap; typedef MapVector<BasicBlock *, BBVector> BB2BBVecMap; -typedef DenseMap<DomTreeNode *, unsigned> DTN2UnsignedMap; typedef DenseMap<BasicBlock *, PhiMap> BBPhiMap; typedef DenseMap<BasicBlock *, Value *> BBPredicates; typedef DenseMap<BasicBlock *, BBPredicates> PredMap; typedef DenseMap<BasicBlock *, BasicBlock*> BB2BBMap; // The name for newly created blocks. - static const char *const FlowBlockName = "Flow"; -/// @brief Find the nearest common dominator for multiple BasicBlocks +/// Finds the nearest common dominator of a set of BasicBlocks. /// -/// Helper class for StructurizeCFG -/// TODO: Maybe move into common code +/// For every BB you add to the set, you can specify whether we "remember" the +/// block. When you get the common dominator, you can also ask whether it's one +/// of the blocks we remembered. class NearestCommonDominator { DominatorTree *DT; + BasicBlock *Result = nullptr; + bool ResultIsRemembered = false; - DTN2UnsignedMap IndexMap; - - BasicBlock *Result; - unsigned ResultIndex; - bool ExplicitMentioned; - -public: - /// \brief Start a new query - NearestCommonDominator(DominatorTree *DomTree) { - DT = DomTree; - Result = nullptr; - } - - /// \brief Add BB to the resulting dominator - void addBlock(BasicBlock *BB, bool Remember = true) { - DomTreeNode *Node = DT->getNode(BB); - + /// Add BB to the resulting dominator. 
+ void addBlock(BasicBlock *BB, bool Remember) { if (!Result) { - unsigned Numbering = 0; - for (;Node;Node = Node->getIDom()) - IndexMap[Node] = ++Numbering; Result = BB; - ResultIndex = 1; - ExplicitMentioned = Remember; + ResultIsRemembered = Remember; return; } - for (;Node;Node = Node->getIDom()) - if (IndexMap.count(Node)) - break; - else - IndexMap[Node] = 0; + BasicBlock *NewResult = DT->findNearestCommonDominator(Result, BB); + if (NewResult != Result) + ResultIsRemembered = false; + if (NewResult == BB) + ResultIsRemembered |= Remember; + Result = NewResult; + } - assert(Node && "Dominator tree invalid!"); +public: + explicit NearestCommonDominator(DominatorTree *DomTree) : DT(DomTree) {} - unsigned Numbering = IndexMap[Node]; - if (Numbering > ResultIndex) { - Result = Node->getBlock(); - ResultIndex = Numbering; - ExplicitMentioned = Remember && (Result == BB); - } else if (Numbering == ResultIndex) { - ExplicitMentioned |= Remember; - } + void addBlock(BasicBlock *BB) { + addBlock(BB, /* Remember = */ false); } - /// \brief Is "Result" one of the BBs added with "Remember" = True? - bool wasResultExplicitMentioned() { - return ExplicitMentioned; + void addAndRememberBlock(BasicBlock *BB) { + addBlock(BB, /* Remember = */ true); } - /// \brief Get the query result - BasicBlock *getResult() { - return Result; - } + /// Get the nearest common dominator of all the BBs added via addBlock() and + /// addAndRememberBlock(). + BasicBlock *result() { return Result; } + + /// Is the BB returned by getResult() one of the blocks we added to the set + /// with addAndRememberBlock()? + bool resultIsRememberedBlock() { return ResultIsRemembered; } }; /// @brief Transforms the control flow graph on one single entry/exit region @@ -141,7 +122,7 @@ public: /// Control flow is expressed as a branch where the true exit goes into the /// "Then"/"Else" region, while the false exit skips the region /// The condition for the optional "Else" region is expressed as a PHI node. -/// The incomming values of the PHI node are true for the "If" edge and false +/// The incoming values of the PHI node are true for the "If" edge and false /// for the "Then" edge. /// /// Additionally to that even complicated loops look like this: @@ -163,7 +144,6 @@ public: /// breaks and the false values expresses continue states. 
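The rewritten NearestCommonDominator above drops the hand-rolled dominator numbering and simply folds DominatorTree::findNearestCommonDominator over the blocks it is given, additionally tracking whether the running result is one of the remembered blocks. The underlying fold, as a free function (the name is illustrative):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Dominators.h"
    using namespace llvm;

    // Nearest common dominator of a non-empty set of blocks, built from the
    // pairwise query the rewritten class now relies on.
    static BasicBlock *commonDominator(DominatorTree &DT,
                                       ArrayRef<BasicBlock *> Blocks) {
      BasicBlock *Result = Blocks.front();
      for (BasicBlock *BB : Blocks.drop_front())
        Result = DT.findNearestCommonDominator(Result, BB);
      return Result;
    }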
class StructurizeCFG : public RegionPass { bool SkipUniformRegions; - DivergenceAnalysis *DA; Type *Boolean; ConstantInt *BoolTrue; @@ -176,7 +156,7 @@ class StructurizeCFG : public RegionPass { DominatorTree *DT; LoopInfo *LI; - RNVector Order; + SmallVector<RegionNode *, 8> Order; BBSet Visited; BBPhiMap DeletedPhis; @@ -236,29 +216,19 @@ class StructurizeCFG : public RegionPass { void rebuildSSA(); - bool hasOnlyUniformBranches(const Region *R); - public: static char ID; - StructurizeCFG() : - RegionPass(ID), SkipUniformRegions(false) { - initializeStructurizeCFGPass(*PassRegistry::getPassRegistry()); - } - - StructurizeCFG(bool SkipUniformRegions) : - RegionPass(ID), SkipUniformRegions(SkipUniformRegions) { + explicit StructurizeCFG(bool SkipUniformRegions = false) + : RegionPass(ID), SkipUniformRegions(SkipUniformRegions) { initializeStructurizeCFGPass(*PassRegistry::getPassRegistry()); } - using Pass::doInitialization; bool doInitialization(Region *R, RGPassManager &RGM) override; bool runOnRegion(Region *R, RGPassManager &RGM) override; - const char *getPassName() const override { - return "Structurize control flow"; - } + StringRef getPassName() const override { return "Structurize control flow"; } void getAnalysisUsage(AnalysisUsage &AU) const override { if (SkipUniformRegions) @@ -266,6 +236,7 @@ public: AU.addRequiredID(LowerSwitchID); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); RegionPass::getAnalysisUsage(AU); } @@ -298,17 +269,13 @@ bool StructurizeCFG::doInitialization(Region *R, RGPassManager &RGM) { /// \brief Build up the general order of nodes void StructurizeCFG::orderNodes() { - RNVector TempOrder; ReversePostOrderTraversal<Region*> RPOT(ParentRegion); - TempOrder.append(RPOT.begin(), RPOT.end()); - - std::map<Loop*, unsigned> LoopBlocks; - + SmallDenseMap<Loop*, unsigned, 8> LoopBlocks; // The reverse post-order traversal of the list gives us an ordering close // to what we want. The only problem with it is that sometimes backedges // for outer loops will be visited before backedges for inner loops. - for (RegionNode *RN : TempOrder) { + for (RegionNode *RN : RPOT) { BasicBlock *BB = RN->getEntry(); Loop *Loop = LI->getLoopFor(BB); ++LoopBlocks[Loop]; @@ -316,19 +283,18 @@ void StructurizeCFG::orderNodes() { unsigned CurrentLoopDepth = 0; Loop *CurrentLoop = nullptr; - BBSet TempVisited; - for (RNVector::iterator I = TempOrder.begin(), E = TempOrder.end(); I != E; ++I) { + for (auto I = RPOT.begin(), E = RPOT.end(); I != E; ++I) { BasicBlock *BB = (*I)->getEntry(); unsigned LoopDepth = LI->getLoopDepth(BB); - if (std::find(Order.begin(), Order.end(), *I) != Order.end()) + if (is_contained(Order, *I)) continue; if (LoopDepth < CurrentLoopDepth) { // Make sure we have visited all blocks in this loop before moving back to // the outer loop. 
- RNVector::iterator LoopI = I; + auto LoopI = I; while (unsigned &BlockCount = LoopBlocks[CurrentLoop]) { LoopI++; BasicBlock *LoopBB = (*LoopI)->getEntry(); @@ -340,9 +306,8 @@ void StructurizeCFG::orderNodes() { } CurrentLoop = LI->getLoopFor(BB); - if (CurrentLoop) { + if (CurrentLoop) LoopBlocks[CurrentLoop]--; - } CurrentLoopDepth = LoopDepth; Order.push_back(*I); @@ -426,46 +391,40 @@ void StructurizeCFG::gatherPredicates(RegionNode *N) { BBPredicates &Pred = Predicates[BB]; BBPredicates &LPred = LoopPreds[BB]; - for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); - PI != PE; ++PI) { - + for (BasicBlock *P : predecessors(BB)) { // Ignore it if it's a branch from outside into our region entry - if (!ParentRegion->contains(*PI)) + if (!ParentRegion->contains(P)) continue; - Region *R = RI->getRegionFor(*PI); + Region *R = RI->getRegionFor(P); if (R == ParentRegion) { - // It's a top level block in our region - BranchInst *Term = cast<BranchInst>((*PI)->getTerminator()); + BranchInst *Term = cast<BranchInst>(P->getTerminator()); for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i) { BasicBlock *Succ = Term->getSuccessor(i); if (Succ != BB) continue; - if (Visited.count(*PI)) { + if (Visited.count(P)) { // Normal forward edge if (Term->isConditional()) { // Try to treat it like an ELSE block BasicBlock *Other = Term->getSuccessor(!i); if (Visited.count(Other) && !Loops.count(Other) && - !Pred.count(Other) && !Pred.count(*PI)) { + !Pred.count(Other) && !Pred.count(P)) { Pred[Other] = BoolFalse; - Pred[*PI] = BoolTrue; + Pred[P] = BoolTrue; continue; } } - Pred[*PI] = buildCondition(Term, i, false); - + Pred[P] = buildCondition(Term, i, false); } else { // Back edge - LPred[*PI] = buildCondition(Term, i, true); + LPred[P] = buildCondition(Term, i, true); } } - } else { - // It's an exit from a sub region while (R->getParent() != ParentRegion) R = R->getParent(); @@ -496,7 +455,6 @@ void StructurizeCFG::collectInfos() { Visited.clear(); for (RegionNode *RN : reverse(Order)) { - DEBUG(dbgs() << "Visiting: " << (RN->isSubRegion() ? "SubRegion with entry: " : "") << RN->getEntry()->getName() << " Loop Depth: " @@ -533,25 +491,26 @@ void StructurizeCFG::insertConditions(bool Loops) { BBPredicates &Preds = Loops ? 
LoopPreds[SuccFalse] : Predicates[SuccTrue]; NearestCommonDominator Dominator(DT); - Dominator.addBlock(Parent, false); + Dominator.addBlock(Parent); Value *ParentValue = nullptr; - for (BBPredicates::iterator PI = Preds.begin(), PE = Preds.end(); - PI != PE; ++PI) { + for (std::pair<BasicBlock *, Value *> BBAndPred : Preds) { + BasicBlock *BB = BBAndPred.first; + Value *Pred = BBAndPred.second; - if (PI->first == Parent) { - ParentValue = PI->second; + if (BB == Parent) { + ParentValue = Pred; break; } - PhiInserter.AddAvailableValue(PI->first, PI->second); - Dominator.addBlock(PI->first); + PhiInserter.AddAvailableValue(BB, Pred); + Dominator.addAndRememberBlock(BB); } if (ParentValue) { Term->setCondition(ParentValue); } else { - if (!Dominator.wasResultExplicitMentioned()) - PhiInserter.AddAvailableValue(Dominator.getResult(), Default); + if (!Dominator.resultIsRememberedBlock()) + PhiInserter.AddAvailableValue(Dominator.result(), Default); Term->setCondition(PhiInserter.GetValueInMiddleOfBlock(Parent)); } @@ -562,10 +521,10 @@ void StructurizeCFG::insertConditions(bool Loops) { /// them in DeletedPhis void StructurizeCFG::delPhiValues(BasicBlock *From, BasicBlock *To) { PhiMap &Map = DeletedPhis[To]; - for (BasicBlock::iterator I = To->begin(), E = To->end(); - I != E && isa<PHINode>(*I);) { - - PHINode &Phi = cast<PHINode>(*I++); + for (Instruction &I : *To) { + if (!isa<PHINode>(I)) + break; + PHINode &Phi = cast<PHINode>(I); while (Phi.getBasicBlockIndex(From) != -1) { Value *Deleted = Phi.removeIncomingValue(From, false); Map[&Phi].push_back(std::make_pair(From, Deleted)); @@ -575,10 +534,10 @@ void StructurizeCFG::delPhiValues(BasicBlock *From, BasicBlock *To) { /// \brief Add a dummy PHI value as soon as we knew the new predecessor void StructurizeCFG::addPhiValues(BasicBlock *From, BasicBlock *To) { - for (BasicBlock::iterator I = To->begin(), E = To->end(); - I != E && isa<PHINode>(*I);) { - - PHINode &Phi = cast<PHINode>(*I++); + for (Instruction &I : *To) { + if (!isa<PHINode>(I)) + break; + PHINode &Phi = cast<PHINode>(I); Value *Undef = UndefValue::get(Phi.getType()); Phi.addIncoming(Undef, From); } @@ -589,7 +548,6 @@ void StructurizeCFG::addPhiValues(BasicBlock *From, BasicBlock *To) { void StructurizeCFG::setPhiValues() { SSAUpdater Updater; for (const auto &AddedPhi : AddedPhis) { - BasicBlock *To = AddedPhi.first; const BBVector &From = AddedPhi.second; @@ -598,7 +556,6 @@ void StructurizeCFG::setPhiValues() { PhiMap &Map = DeletedPhis[To]; for (const auto &PI : Map) { - PHINode *Phi = PI.first; Value *Undef = UndefValue::get(Phi->getType()); Updater.Initialize(Phi->getType(), ""); @@ -606,18 +563,16 @@ void StructurizeCFG::setPhiValues() { Updater.AddAvailableValue(To, Undef); NearestCommonDominator Dominator(DT); - Dominator.addBlock(To, false); + Dominator.addBlock(To); for (const auto &VI : PI.second) { - Updater.AddAvailableValue(VI.first, VI.second); - Dominator.addBlock(VI.first); + Dominator.addAndRememberBlock(VI.first); } - if (!Dominator.wasResultExplicitMentioned()) - Updater.AddAvailableValue(Dominator.getResult(), Undef); + if (!Dominator.resultIsRememberedBlock()) + Updater.AddAvailableValue(Dominator.result(), Undef); for (BasicBlock *FI : From) { - int Idx = Phi->getBasicBlockIndex(FI); assert(Idx != -1); Phi->setIncomingValue(Idx, Updater.GetValueAtEndOfBlock(FI)); @@ -636,10 +591,8 @@ void StructurizeCFG::killTerminator(BasicBlock *BB) { return; for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); - SI != SE; ++SI) { - + SI != SE; ++SI) 
delPhiValues(BB, *SI); - } Term->eraseFromParent(); } @@ -653,10 +606,10 @@ void StructurizeCFG::changeExit(RegionNode *Node, BasicBlock *NewExit, BasicBlock *Dominator = nullptr; // Find all the edges from the sub region to the exit - for (pred_iterator I = pred_begin(OldExit), E = pred_end(OldExit); - I != E;) { + for (auto BBI = pred_begin(OldExit), E = pred_end(OldExit); BBI != E;) { + // Increment BBI before mucking with BB's terminator. + BasicBlock *BB = *BBI++; - BasicBlock *BB = *I++; if (!SubRegion->contains(BB)) continue; @@ -680,7 +633,6 @@ void StructurizeCFG::changeExit(RegionNode *Node, BasicBlock *NewExit, // Update the region info SubRegion->replaceExit(NewExit); - } else { BasicBlock *BB = Node->getNodeAs<BasicBlock>(); killTerminator(BB); @@ -711,7 +663,6 @@ BasicBlock *StructurizeCFG::needPrefix(bool NeedEmpty) { killTerminator(Entry); if (!NeedEmpty || Entry->getFirstInsertionPt() == Entry->end()) return Entry; - } // create a new flow node @@ -726,13 +677,13 @@ BasicBlock *StructurizeCFG::needPrefix(bool NeedEmpty) { /// \brief Returns the region exit if possible, otherwise just a new flow node BasicBlock *StructurizeCFG::needPostfix(BasicBlock *Flow, bool ExitUseAllowed) { - if (Order.empty() && ExitUseAllowed) { - BasicBlock *Exit = ParentRegion->getExit(); - DT->changeImmediateDominator(Exit, Flow); - addPhiValues(Flow, Exit); - return Exit; - } - return getNextFlow(Flow); + if (!Order.empty() || !ExitUseAllowed) + return getNextFlow(Flow); + + BasicBlock *Exit = ParentRegion->getExit(); + DT->changeImmediateDominator(Exit, Flow); + addPhiValues(Flow, Exit); + return Exit; } /// \brief Set the previous node @@ -741,16 +692,12 @@ void StructurizeCFG::setPrevNode(BasicBlock *BB) { : nullptr; } -/// \brief Does BB dominate all the predicates of Node ? +/// \brief Does BB dominate all the predicates of Node? bool StructurizeCFG::dominatesPredicates(BasicBlock *BB, RegionNode *Node) { BBPredicates &Preds = Predicates[Node->getEntry()]; - for (BBPredicates::iterator PI = Preds.begin(), PE = Preds.end(); - PI != PE; ++PI) { - - if (!DT->dominates(BB, PI->first)) - return false; - } - return true; + return llvm::all_of(Preds, [&](std::pair<BasicBlock *, Value *> Pred) { + return DT->dominates(BB, Pred.first); + }); } /// \brief Can we predict that this node will always be called? @@ -762,13 +709,14 @@ bool StructurizeCFG::isPredictableTrue(RegionNode *Node) { if (!PrevNode) return true; - for (BBPredicates::iterator I = Preds.begin(), E = Preds.end(); - I != E; ++I) { + for (std::pair<BasicBlock*, Value*> Pred : Preds) { + BasicBlock *BB = Pred.first; + Value *V = Pred.second; - if (I->second != BoolTrue) + if (V != BoolTrue) return false; - if (!Dominated && DT->dominates(I->first, PrevNode->getEntry())) + if (!Dominated && DT->dominates(BB, PrevNode->getEntry())) Dominated = true; } @@ -844,6 +792,7 @@ void StructurizeCFG::handleLoops(bool ExitUseAllowed, LoopFunc, LoopStart); BranchInst::Create(LoopStart, NewEntry); + DT->setNewRoot(NewEntry); } // Create an extra loop end node @@ -883,30 +832,29 @@ void StructurizeCFG::createFlow() { /// no longer dominate all their uses.
Not sure if this is really necessary void StructurizeCFG::rebuildSSA() { SSAUpdater Updater; - for (auto *BB : ParentRegion->blocks()) - for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); - II != IE; ++II) { - + for (BasicBlock *BB : ParentRegion->blocks()) + for (Instruction &I : *BB) { bool Initialized = false; - for (auto I = II->use_begin(), E = II->use_end(); I != E;) { - Use &U = *I++; + // We may modify the use list as we iterate over it, so be careful to + // compute the next element in the use list at the top of the loop. + for (auto UI = I.use_begin(), E = I.use_end(); UI != E;) { + Use &U = *UI++; Instruction *User = cast<Instruction>(U.getUser()); if (User->getParent() == BB) { continue; - } else if (PHINode *UserPN = dyn_cast<PHINode>(User)) { if (UserPN->getIncomingBlock(U) == BB) continue; } - if (DT->dominates(&*II, User)) + if (DT->dominates(&I, User)) continue; if (!Initialized) { - Value *Undef = UndefValue::get(II->getType()); - Updater.Initialize(II->getType(), ""); + Value *Undef = UndefValue::get(I.getType()); + Updater.Initialize(I.getType(), ""); Updater.AddAvailableValue(&Func->getEntryBlock(), Undef); - Updater.AddAvailableValue(BB, &*II); + Updater.AddAvailableValue(BB, &I); Initialized = true; } Updater.RewriteUseAfterInsertions(U); @@ -914,13 +862,14 @@ } } -bool StructurizeCFG::hasOnlyUniformBranches(const Region *R) { +static bool hasOnlyUniformBranches(const Region *R, + const DivergenceAnalysis &DA) { for (const BasicBlock *BB : R->blocks()) { const BranchInst *Br = dyn_cast<BranchInst>(BB->getTerminator()); if (!Br || !Br->isConditional()) continue; - if (!DA->isUniform(Br->getCondition())) + if (!DA.isUniform(Br->getCondition())) return false; DEBUG(dbgs() << "BB: " << BB->getName() << " has uniform terminator\n"); } @@ -933,9 +882,9 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { return false; if (SkipUniformRegions) { - DA = &getAnalysis<DivergenceAnalysis>(); // TODO: We could probably be smarter here with how we handle sub-regions. - if (hasOnlyUniformBranches(R)) { + auto &DA = getAnalysis<DivergenceAnalysis>(); + if (hasOnlyUniformBranches(R, DA)) { DEBUG(dbgs() << "Skipping region with uniform control flow: " << *R << '\n'); // Mark all direct child block terminators as having been treated as // uniform. To account for a possible future in which non-uniform // sub-regions are treated more cleverly, indirect children are not // marked as uniform.
MDNode *MD = MDNode::get(R->getEntry()->getParent()->getContext(), {}); - Region::element_iterator E = R->element_end(); - for (Region::element_iterator I = R->element_begin(); I != E; ++I) { - if (I->isSubRegion()) + for (RegionNode *E : R->elements()) { + if (E->isSubRegion()) continue; - if (Instruction *Term = I->getEntry()->getTerminator()) + if (Instruction *Term = E->getEntry()->getTerminator()) Term->setMetadata("structurizecfg.uniform", MD); } diff --git a/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index d5ff997..a6b9fee 100644 --- a/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -236,7 +236,7 @@ static bool markTails(Function &F, bool &AllCallsAreTailCalls) { if (!CI || CI->isTailCall()) continue; - bool IsNoTail = CI->isNoTailCall(); + bool IsNoTail = CI->isNoTailCall() || CI->hasOperandBundles(); if (!IsNoTail && CI->doesNotAccessMemory()) { // A call to a readnone function whose arguments are all things computed @@ -347,7 +347,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI) { // return value of the call, it must only use things that are defined before // the call, or movable instructions between the call and the instruction // itself. - return std::find(I->op_begin(), I->op_end(), CI) == I->op_end(); + return !is_contained(I->operands(), CI); } /// Return true if the specified value is the same when the return would exit diff --git a/contrib/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp b/contrib/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp index 7e50d4b..df9d5da 100644 --- a/contrib/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp +++ b/contrib/llvm/lib/Transforms/Utils/ASanStackFrameLayout.cpp @@ -12,7 +12,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/ASanStackFrameLayout.h" #include "llvm/ADT/SmallString.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" #include <algorithm> @@ -47,64 +49,102 @@ static size_t VarAndRedzoneSize(size_t Size, size_t Alignment) { return alignTo(Res, Alignment); } -void +ASanStackFrameLayout ComputeASanStackFrameLayout(SmallVectorImpl<ASanStackVariableDescription> &Vars, - size_t Granularity, size_t MinHeaderSize, - ASanStackFrameLayout *Layout) { + size_t Granularity, size_t MinHeaderSize) { assert(Granularity >= 8 && Granularity <= 64 && (Granularity & (Granularity - 1)) == 0); assert(MinHeaderSize >= 16 && (MinHeaderSize & (MinHeaderSize - 1)) == 0 && MinHeaderSize >= Granularity); - size_t NumVars = Vars.size(); + const size_t NumVars = Vars.size(); assert(NumVars > 0); for (size_t i = 0; i < NumVars; i++) Vars[i].Alignment = std::max(Vars[i].Alignment, kMinAlignment); std::stable_sort(Vars.begin(), Vars.end(), CompareVars); - SmallString<2048> StackDescriptionStorage; - raw_svector_ostream StackDescription(StackDescriptionStorage); - StackDescription << NumVars; - Layout->FrameAlignment = std::max(Granularity, Vars[0].Alignment); - SmallVector<uint8_t, 64> &SB(Layout->ShadowBytes); - SB.clear(); + + ASanStackFrameLayout Layout; + Layout.Granularity = Granularity; + Layout.FrameAlignment = std::max(Granularity, Vars[0].Alignment); size_t Offset = std::max(std::max(MinHeaderSize, Granularity), Vars[0].Alignment); assert((Offset % Granularity) == 0); - SB.insert(SB.end(), 
Offset / Granularity, kAsanStackLeftRedzoneMagic); for (size_t i = 0; i < NumVars; i++) { bool IsLast = i == NumVars - 1; size_t Alignment = std::max(Granularity, Vars[i].Alignment); (void)Alignment; // Used only in asserts. size_t Size = Vars[i].Size; - const char *Name = Vars[i].Name; assert((Alignment & (Alignment - 1)) == 0); - assert(Layout->FrameAlignment >= Alignment); + assert(Layout.FrameAlignment >= Alignment); assert((Offset % Alignment) == 0); assert(Size > 0); - StackDescription << " " << Offset << " " << Size << " " << strlen(Name) - << " " << Name; size_t NextAlignment = IsLast ? Granularity : std::max(Granularity, Vars[i + 1].Alignment); - size_t SizeWithRedzone = VarAndRedzoneSize(Vars[i].Size, NextAlignment); - SB.insert(SB.end(), Size / Granularity, 0); - if (Size % Granularity) - SB.insert(SB.end(), Size % Granularity); - SB.insert(SB.end(), (SizeWithRedzone - Size) / Granularity, - IsLast ? kAsanStackRightRedzoneMagic - : kAsanStackMidRedzoneMagic); + size_t SizeWithRedzone = VarAndRedzoneSize(Size, NextAlignment); Vars[i].Offset = Offset; Offset += SizeWithRedzone; } if (Offset % MinHeaderSize) { - size_t ExtraRedzone = MinHeaderSize - (Offset % MinHeaderSize); - SB.insert(SB.end(), ExtraRedzone / Granularity, - kAsanStackRightRedzoneMagic); - Offset += ExtraRedzone; + Offset += MinHeaderSize - (Offset % MinHeaderSize); + } + Layout.FrameSize = Offset; + assert((Layout.FrameSize % MinHeaderSize) == 0); + return Layout; +} + +SmallString<64> ComputeASanStackFrameDescription( + const SmallVectorImpl<ASanStackVariableDescription> &Vars) { + SmallString<2048> StackDescriptionStorage; + raw_svector_ostream StackDescription(StackDescriptionStorage); + StackDescription << Vars.size(); + + for (const auto &Var : Vars) { + std::string Name = Var.Name; + if (Var.Line) { + Name += ":"; + Name += to_string(Var.Line); + } + StackDescription << " " << Var.Offset << " " << Var.Size << " " + << Name.size() << " " << Name; } - Layout->DescriptionString = StackDescription.str(); - Layout->FrameSize = Offset; - assert((Layout->FrameSize % MinHeaderSize) == 0); - assert(Layout->FrameSize / Granularity == Layout->ShadowBytes.size()); + return StackDescription.str(); +} + +SmallVector<uint8_t, 64> +GetShadowBytes(const SmallVectorImpl<ASanStackVariableDescription> &Vars, + const ASanStackFrameLayout &Layout) { + assert(Vars.size() > 0); + SmallVector<uint8_t, 64> SB; + SB.clear(); + const size_t Granularity = Layout.Granularity; + SB.resize(Vars[0].Offset / Granularity, kAsanStackLeftRedzoneMagic); + for (const auto &Var : Vars) { + SB.resize(Var.Offset / Granularity, kAsanStackMidRedzoneMagic); + + SB.resize(SB.size() + Var.Size / Granularity, 0); + if (Var.Size % Granularity) + SB.push_back(Var.Size % Granularity); + } + SB.resize(Layout.FrameSize / Granularity, kAsanStackRightRedzoneMagic); + return SB; +} + +SmallVector<uint8_t, 64> GetShadowBytesAfterScope( + const SmallVectorImpl<ASanStackVariableDescription> &Vars, + const ASanStackFrameLayout &Layout) { + SmallVector<uint8_t, 64> SB = GetShadowBytes(Vars, Layout); + const size_t Granularity = Layout.Granularity; + + for (const auto &Var : Vars) { + assert(Var.LifetimeSize <= Var.Size); + const size_t LifetimeShadowSize = + (Var.LifetimeSize + Granularity - 1) / Granularity; + const size_t Offset = Var.Offset / Granularity; + std::fill(SB.begin() + Offset, SB.begin() + Offset + LifetimeShadowSize, + kAsanStackUseAfterScopeMagic); + } + + return SB; } } // llvm namespace diff --git 
a/contrib/llvm/lib/Transforms/Utils/AddDiscriminators.cpp b/contrib/llvm/lib/Transforms/Utils/AddDiscriminators.cpp index d034905..2e95926 100644 --- a/contrib/llvm/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/contrib/llvm/lib/Transforms/Utils/AddDiscriminators.cpp @@ -57,12 +57,10 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DIBuilder.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -159,20 +157,14 @@ static bool addDiscriminators(Function &F) { // If the function has debug information, but the user has disabled // discriminators, do nothing. // Simlarly, if the function has no debug info, do nothing. - // Finally, if this module is built with dwarf versions earlier than 4, - // do nothing (discriminator support is a DWARF 4 feature). - if (NoDiscriminators || !F.getSubprogram() || - F.getParent()->getDwarfVersion() < 4) + if (NoDiscriminators || !F.getSubprogram()) return false; bool Changed = false; - Module *M = F.getParent(); - LLVMContext &Ctx = M->getContext(); - DIBuilder Builder(*M, /*AllowUnresolved*/ false); typedef std::pair<StringRef, unsigned> Location; - typedef DenseMap<const BasicBlock *, Metadata *> BBScopeMap; - typedef DenseMap<Location, BBScopeMap> LocationBBMap; + typedef DenseSet<const BasicBlock *> BBSet; + typedef DenseMap<Location, BBSet> LocationBBMap; typedef DenseMap<Location, unsigned> LocationDiscriminatorMap; typedef DenseSet<Location> LocationSet; @@ -184,32 +176,25 @@ static bool addDiscriminators(Function &F) { // discriminator for this instruction. for (BasicBlock &B : F) { for (auto &I : B.getInstList()) { - if (isa<DbgInfoIntrinsic>(&I)) + if (isa<IntrinsicInst>(&I)) continue; const DILocation *DIL = I.getDebugLoc(); if (!DIL) continue; Location L = std::make_pair(DIL->getFilename(), DIL->getLine()); auto &BBMap = LBM[L]; - auto R = BBMap.insert(std::make_pair(&B, (Metadata *)nullptr)); + auto R = BBMap.insert(&B); if (BBMap.size() == 1) continue; - bool InsertSuccess = R.second; - Metadata *&NewScope = R.first->second; - // If we could insert a different block in the same location, a + // If we could insert more than one block with the same line+file, a // discriminator is needed to distinguish both instructions. - if (InsertSuccess) { - auto *Scope = DIL->getScope(); - auto *File = - Builder.createFile(DIL->getFilename(), Scope->getDirectory()); - NewScope = Builder.createLexicalBlockFile(Scope, File, ++LDM[L]); - } - I.setDebugLoc(DILocation::get(Ctx, DIL->getLine(), DIL->getColumn(), - NewScope, DIL->getInlinedAt())); + // Only the lowest 7 bits are used to represent a discriminator to fit + // it in 1 byte ULEB128 representation. + unsigned Discriminator = (R.second ? 
++LDM[L] : LDM[L]) & 0x7f; + I.setDebugLoc(DIL->cloneWithDiscriminator(Discriminator)); DEBUG(dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" - << DIL->getColumn() << ":" - << dyn_cast<DILexicalBlockFile>(NewScope)->getDiscriminator() - << I << "\n"); + << DIL->getColumn() << ":" << Discriminator << " " << I + << "\n"); Changed = true; } } @@ -222,7 +207,7 @@ static bool addDiscriminators(Function &F) { LocationSet CallLocations; for (auto &I : B.getInstList()) { CallInst *Current = dyn_cast<CallInst>(&I); - if (!Current || isa<DbgInfoIntrinsic>(&I)) + if (!Current || isa<IntrinsicInst>(&I)) continue; DILocation *CurrentDIL = Current->getDebugLoc(); @@ -231,13 +216,8 @@ static bool addDiscriminators(Function &F) { Location L = std::make_pair(CurrentDIL->getFilename(), CurrentDIL->getLine()); if (!CallLocations.insert(L).second) { - auto *Scope = CurrentDIL->getScope(); - auto *File = Builder.createFile(CurrentDIL->getFilename(), - Scope->getDirectory()); - auto *NewScope = Builder.createLexicalBlockFile(Scope, File, ++LDM[L]); - Current->setDebugLoc(DILocation::get(Ctx, CurrentDIL->getLine(), - CurrentDIL->getColumn(), NewScope, - CurrentDIL->getInlinedAt())); + Current->setDebugLoc( + CurrentDIL->cloneWithDiscriminator((++LDM[L]) & 0x7f)); Changed = true; } } @@ -249,7 +229,7 @@ bool AddDiscriminatorsLegacyPass::runOnFunction(Function &F) { return addDiscriminators(F); } PreservedAnalyses AddDiscriminatorsPass::run(Function &F, - AnalysisManager<Function> &AM) { + FunctionAnalysisManager &AM) { if (!addDiscriminators(F)) return PreservedAnalyses::all(); diff --git a/contrib/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp b/contrib/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp index 49b646a..175cbd2 100644 --- a/contrib/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/contrib/llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -15,7 +15,7 @@ // //===----------------------------------------------------------------------===// -#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BreakCriticalEdges.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -23,10 +23,10 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Type.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -72,6 +72,20 @@ FunctionPass *llvm::createBreakCriticalEdgesPass() { return new BreakCriticalEdges(); } +PreservedAnalyses BreakCriticalEdgesPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + auto *LI = AM.getCachedResult<LoopAnalysis>(F); + unsigned N = SplitAllCriticalEdges(F, CriticalEdgeSplittingOptions(DT, LI)); + NumBroken += N; + if (N == 0) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); + return PA; +} + //===----------------------------------------------------------------------===// // Implementation of the external critical edge manipulation functions //===----------------------------------------------------------------------===// diff --git a/contrib/llvm/lib/Transforms/Utils/BuildLibCalls.cpp b/contrib/llvm/lib/Transforms/Utils/BuildLibCalls.cpp index f4260a9..e61b04f 100644 --- a/contrib/llvm/lib/Transforms/Utils/BuildLibCalls.cpp +++ 
b/contrib/llvm/lib/Transforms/Utils/BuildLibCalls.cpp @@ -250,6 +250,7 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setDoesNotCapture(F, 2); return Changed; case LibFunc::memcpy: + case LibFunc::mempcpy: case LibFunc::memccpy: case LibFunc::memmove: Changed |= setDoesNotThrow(F); diff --git a/contrib/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp b/contrib/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp index 42287d3..bc2cef2 100644 --- a/contrib/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/contrib/llvm/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -20,6 +20,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -82,13 +83,17 @@ static bool insertFastDiv(Instruction *I, IntegerType *BypassType, Value *Dividend = I->getOperand(0); Value *Divisor = I->getOperand(1); - if (isa<ConstantInt>(Divisor) || - (isa<ConstantInt>(Dividend) && isa<ConstantInt>(Divisor))) { - // Operations with immediate values should have - // been solved and replaced during compile time. + if (isa<ConstantInt>(Divisor)) { + // Division by a constant should have been solved and replaced earlier + // in the pipeline. return false; } + // If the numerator is a constant, bail if it doesn't fit into BypassType. + if (ConstantInt *ConstDividend = dyn_cast<ConstantInt>(Dividend)) + if (ConstDividend->getValue().getActiveBits() > BypassType->getBitWidth()) + return false; + // Basic Block is split before divide BasicBlock *MainBB = &*I->getParent(); BasicBlock *SuccessorBB = MainBB->splitBasicBlock(I); @@ -120,8 +125,7 @@ static bool insertFastDiv(Instruction *I, IntegerType *BypassType, BypassType); // udiv/urem because optimization only handles positive numbers - Value *ShortQuotientV = FastBuilder.CreateExactUDiv(ShortDividendV, - ShortDivisorV); + Value *ShortQuotientV = FastBuilder.CreateUDiv(ShortDividendV, ShortDivisorV); Value *ShortRemainderV = FastBuilder.CreateURem(ShortDividendV, ShortDivisorV); Value *FastQuotientV = FastBuilder.CreateCast(Instruction::ZExt, @@ -151,7 +155,17 @@ static bool insertFastDiv(Instruction *I, IntegerType *BypassType, // Combine operands into a single value with OR for value testing below MainBB->getInstList().back().eraseFromParent(); IRBuilder<> MainBuilder(MainBB, MainBB->end()); - Value *OrV = MainBuilder.CreateOr(Dividend, Divisor); + + // We should have bailed out above if the divisor is a constant, but the + // dividend may still be a constant. Set OrV to our non-constant operands + // OR'ed together. + assert(!isa<ConstantInt>(Divisor)); + + Value *OrV; + if (!isa<ConstantInt>(Dividend)) + OrV = MainBuilder.CreateOr(Dividend, Divisor); + else + OrV = Divisor; // BitMask is inverted to check if the operands are // larger than the bypass type @@ -247,5 +261,12 @@ bool llvm::bypassSlowDivision( MadeChange |= reuseOrInsertFastDiv(I, BT, UseDivOp, UseSignedOp, DivCache); } + // Above we eagerly create divs and rems, as pairs, so that we can efficiently + // create divrem machine instructions. Now erase any unused divs / rems so we + // don't leave extra instructions sitting around.
+ for (auto &KV : DivCache) + for (Instruction *Phi : {KV.second.Quotient, KV.second.Remainder}) + RecursivelyDeleteTriviallyDeadInstructions(Phi); + return MadeChange; } diff --git a/contrib/llvm/lib/Transforms/Utils/CloneModule.cpp b/contrib/llvm/lib/Transforms/Utils/CloneModule.cpp index 17e34c4..7ebeb61 100644 --- a/contrib/llvm/lib/Transforms/Utils/CloneModule.cpp +++ b/contrib/llvm/lib/Transforms/Utils/CloneModule.cpp @@ -119,6 +119,11 @@ std::unique_ptr<Module> llvm::CloneModule( } if (I->hasInitializer()) GV->setInitializer(MapValue(I->getInitializer(), VMap)); + + SmallVector<std::pair<unsigned, MDNode *>, 1> MDs; + I->getAllMetadata(MDs); + for (auto MD : MDs) + GV->addMetadata(MD.first, *MapMetadata(MD.second, VMap)); } // Similarly, copy over function bodies now... diff --git a/contrib/llvm/lib/Transforms/Utils/CmpInstAnalysis.cpp b/contrib/llvm/lib/Transforms/Utils/CmpInstAnalysis.cpp index 3b15a0a..60ae374 100644 --- a/contrib/llvm/lib/Transforms/Utils/CmpInstAnalysis.cpp +++ b/contrib/llvm/lib/Transforms/Utils/CmpInstAnalysis.cpp @@ -18,29 +18,6 @@ using namespace llvm; -/// getICmpCode - Encode a icmp predicate into a three bit mask. These bits -/// are carefully arranged to allow folding of expressions such as: -/// -/// (A < B) | (A > B) --> (A != B) -/// -/// Note that this is only valid if the first and second predicates have the -/// same sign. Is illegal to do: (A u< B) | (A s> B) -/// -/// Three bits are used to represent the condition, as follows: -/// 0 A > B -/// 1 A == B -/// 2 A < B -/// -/// <=> Value Definition -/// 000 0 Always false -/// 001 1 A > B -/// 010 2 A == B -/// 011 3 A >= B -/// 100 4 A < B -/// 101 5 A != B -/// 110 6 A <= B -/// 111 7 Always true -/// unsigned llvm::getICmpCode(const ICmpInst *ICI, bool InvertPred) { ICmpInst::Predicate Pred = InvertPred ? ICI->getInversePredicate() : ICI->getPredicate(); @@ -62,13 +39,6 @@ unsigned llvm::getICmpCode(const ICmpInst *ICI, bool InvertPred) { } } -/// getICmpValue - This is the complement of getICmpCode, which turns an -/// opcode and two operands into either a constant true or false, or the -/// predicate for a new ICmp instruction. The sign is passed in to determine -/// which kind of predicate to use in the new icmp instruction. -/// Non-NULL return value will be a true or false constant. -/// NULL return means a new ICmp is needed. The predicate for which is -/// output in NewICmpPred. Value *llvm::getICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, CmpInst::Predicate &NewICmpPred) { switch (Code) { @@ -87,10 +57,52 @@ Value *llvm::getICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, return nullptr; } -/// PredicatesFoldable - Return true if both predicates match sign or if at -/// least one of them is an equality comparison (which is signless). bool llvm::PredicatesFoldable(ICmpInst::Predicate p1, ICmpInst::Predicate p2) { return (CmpInst::isSigned(p1) == CmpInst::isSigned(p2)) || (CmpInst::isSigned(p1) && ICmpInst::isEquality(p2)) || (CmpInst::isSigned(p2) && ICmpInst::isEquality(p1)); } + +bool llvm::decomposeBitTestICmp(const ICmpInst *I, CmpInst::Predicate &Pred, + Value *&X, Value *&Y, Value *&Z) { + ConstantInt *C = dyn_cast<ConstantInt>(I->getOperand(1)); + if (!C) + return false; + + switch (I->getPredicate()) { + default: + return false; + case ICmpInst::ICMP_SLT: + // X < 0 is equivalent to (X & SignBit) != 0. 
+ if (!C->isZero()) + return false; + Y = ConstantInt::get(I->getContext(), APInt::getSignBit(C->getBitWidth())); + Pred = ICmpInst::ICMP_NE; + break; + case ICmpInst::ICMP_SGT: + // X > -1 is equivalent to (X & SignBit) == 0. + if (!C->isAllOnesValue()) + return false; + Y = ConstantInt::get(I->getContext(), APInt::getSignBit(C->getBitWidth())); + Pred = ICmpInst::ICMP_EQ; + break; + case ICmpInst::ICMP_ULT: + // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0. + if (!C->getValue().isPowerOf2()) + return false; + Y = ConstantInt::get(I->getContext(), -C->getValue()); + Pred = ICmpInst::ICMP_EQ; + break; + case ICmpInst::ICMP_UGT: + // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0. + if (!(C->getValue() + 1).isPowerOf2()) + return false; + Y = ConstantInt::get(I->getContext(), ~C->getValue()); + Pred = ICmpInst::ICMP_NE; + break; + } + + X = I->getOperand(0); + Z = ConstantInt::getNullValue(C->getType()); + return true; +} diff --git a/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp index 9f2181f..c514c9c 100644 --- a/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/contrib/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -17,6 +17,9 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" @@ -26,9 +29,11 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" @@ -49,7 +54,7 @@ AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, cl::desc("Aggregate arguments to code-extracted functions")); /// \brief Test whether a block is valid for extraction. -static bool isBlockValidForExtraction(const BasicBlock &BB) { +bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { // Landing pads must be in the function where they were inserted for cleanup. 
if (BB.isEHPad()) return false; @@ -81,7 +86,7 @@ static SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin, if (!Result.insert(*BBBegin)) llvm_unreachable("Repeated basic blocks in extraction input"); - if (!isBlockValidForExtraction(**BBBegin)) { + if (!CodeExtractor::isBlockValidForExtraction(**BBBegin)) { Result.clear(); return Result; } @@ -119,23 +124,30 @@ buildExtractionBlockSet(const RegionNode &RN) { return buildExtractionBlockSet(R.block_begin(), R.block_end()); } -CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs) - : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} +CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs, + BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI) + : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), + BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, - bool AggregateArgs) - : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} - -CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs) - : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {} + bool AggregateArgs, BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI) + : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), + BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} + +CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, + BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI) + : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), + BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())), + NumExitBlocks(~0U) {} CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN, - bool AggregateArgs) - : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), - Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} + bool AggregateArgs, BlockFrequencyInfo *BFI, + BranchProbabilityInfo *BPI) + : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), + BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} /// definedInRegion - Return true if the specified value is defined in the /// extracted region. @@ -339,7 +351,22 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, // If the old function is no-throw, so is the new one. if (oldFunction->doesNotThrow()) newFunction->setDoesNotThrow(); - + + // Inherit the uwtable attribute if we need to. + if (oldFunction->hasUWTable()) + newFunction->setHasUWTable(); + + // Inherit all of the target dependent attributes. + // (e.g. If the extracted region contains a call to an x86.sse + // instruction we need to make sure that the extracted region has the + // "target-features" attribute allowing it to be lowered. + // FIXME: This should be changed to check to see if a specific + // attribute can not be inherited. + AttributeSet OldFnAttrs = oldFunction->getAttributes().getFnAttributes(); + AttrBuilder AB(OldFnAttrs, AttributeSet::FunctionIndex); + for (auto Attr : AB.td_attrs()) + newFunction->addFnAttr(Attr.first, Attr.second); + newFunction->getBasicBlockList().push_back(newRootNode); // Create an iterator to name all of the arguments we inserted. 
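A side note on the attribute-inheritance hunk above: the sketch below is not part of the patch. The helper name is invented, and it assumes the pre-AttributeList C++ API used in this tree; it simply mirrors the AttrBuilder loop that copies target-dependent string attributes such as "target-features" onto the extracted function.

#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
using namespace llvm;

// Illustrative helper (name invented): copy every target-dependent, string
// keyed function attribute, e.g. "target-features", from Src to Dst so the
// extracted body can still be lowered for the original subtarget.
static void copyTargetDependentFnAttrs(const Function &Src, Function &Dst) {
  AttributeSet SrcFnAttrs = Src.getAttributes().getFnAttributes();
  AttrBuilder AB(SrcFnAttrs, AttributeSet::FunctionIndex);
  for (const auto &Attr : AB.td_attrs())
    Dst.addFnAttr(Attr.first, Attr.second);
}

As the FIXME in the hunk notes, a finer-grained approach would copy only those attributes that can legitimately be inherited.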
@@ -672,6 +699,51 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) { } } +void CodeExtractor::calculateNewCallTerminatorWeights( + BasicBlock *CodeReplacer, + DenseMap<BasicBlock *, BlockFrequency> &ExitWeights, + BranchProbabilityInfo *BPI) { + typedef BlockFrequencyInfoImplBase::Distribution Distribution; + typedef BlockFrequencyInfoImplBase::BlockNode BlockNode; + + // Update the branch weights for the exit block. + TerminatorInst *TI = CodeReplacer->getTerminator(); + SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0); + + // Block Frequency distribution with dummy node. + Distribution BranchDist; + + // Add each of the frequencies of the successors. + for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) { + BlockNode ExitNode(i); + uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency(); + if (ExitFreq != 0) + BranchDist.addExit(ExitNode, ExitFreq); + else + BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero()); + } + + // Check for no total weight. + if (BranchDist.Total == 0) + return; + + // Normalize the distribution so that they can fit in unsigned. + BranchDist.normalize(); + + // Create normalized branch weights and set the metadata. + for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) { + const auto &Weight = BranchDist.Weights[I]; + + // Get the weight and update the current BFI. + BranchWeights[Weight.TargetNode.Index] = Weight.Amount; + BranchProbability BP(Weight.Amount, BranchDist.Total); + BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP); + } + TI->setMetadata( + LLVMContext::MD_prof, + MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); +} + Function *CodeExtractor::extractCodeRegion() { if (!isEligible()) return nullptr; @@ -682,6 +754,19 @@ Function *CodeExtractor::extractCodeRegion() { // block in the region. BasicBlock *header = *Blocks.begin(); + // Calculate the entry frequency of the new function before we change the root + // block. + BlockFrequency EntryFreq; + if (BFI) { + assert(BPI && "Both BPI and BFI are required to preserve profile info"); + for (BasicBlock *Pred : predecessors(header)) { + if (Blocks.count(Pred)) + continue; + EntryFreq += + BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header); + } + } + // If we have to split PHI nodes or the entry block, do so now. severSplitPHINodes(header); @@ -705,12 +790,23 @@ Function *CodeExtractor::extractCodeRegion() { // Find inputs to, outputs from the code region. findInputsOutputs(inputs, outputs); + // Calculate the exit blocks for the extracted region and the total exit + // weights for each of those blocks. + DenseMap<BasicBlock *, BlockFrequency> ExitWeights; SmallPtrSet<BasicBlock *, 1> ExitBlocks; - for (BasicBlock *Block : Blocks) + for (BasicBlock *Block : Blocks) { for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; - ++SI) - if (!Blocks.count(*SI)) + ++SI) { + if (!Blocks.count(*SI)) { + // Update the branch weight for this successor. + if (BFI) { + BlockFrequency &BF = ExitWeights[*SI]; + BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI); + } ExitBlocks.insert(*SI); + } + } + } NumExitBlocks = ExitBlocks.size(); // Construct new function based on inputs/outputs & add allocas for all defs. @@ -719,10 +815,23 @@ Function *CodeExtractor::extractCodeRegion() { codeReplacer, oldFunction, oldFunction->getParent()); + // Update the entry count of the function. 
+ if (BFI) { + Optional<uint64_t> EntryCount = + BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); + if (EntryCount.hasValue()) + newFunction->setEntryCount(EntryCount.getValue()); + BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); + } + emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); moveCodeToFunction(newFunction); + // Update the branch weights for the exit block. + if (BFI && NumExitBlocks > 1) + calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); + // Loop over all of the PHI nodes in the header block, and change any // references to the old incoming edge to be the new incoming edge. for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) { diff --git a/contrib/llvm/lib/Transforms/Utils/CtorUtils.cpp b/contrib/llvm/lib/Transforms/Utils/CtorUtils.cpp index b56ff68..6642a97 100644 --- a/contrib/llvm/lib/Transforms/Utils/CtorUtils.cpp +++ b/contrib/llvm/lib/Transforms/Utils/CtorUtils.cpp @@ -71,8 +71,8 @@ std::vector<Function *> parseGlobalCtors(GlobalVariable *GV) { ConstantArray *CA = cast<ConstantArray>(GV->getInitializer()); std::vector<Function *> Result; Result.reserve(CA->getNumOperands()); - for (User::op_iterator i = CA->op_begin(), e = CA->op_end(); i != e; ++i) { - ConstantStruct *CS = cast<ConstantStruct>(*i); + for (auto &V : CA->operands()) { + ConstantStruct *CS = cast<ConstantStruct>(V); Result.push_back(dyn_cast<Function>(CS->getOperand(1))); } return Result; @@ -94,10 +94,10 @@ GlobalVariable *findGlobalCtors(Module &M) { return GV; ConstantArray *CA = cast<ConstantArray>(GV->getInitializer()); - for (User::op_iterator i = CA->op_begin(), e = CA->op_end(); i != e; ++i) { - if (isa<ConstantAggregateZero>(*i)) + for (auto &V : CA->operands()) { + if (isa<ConstantAggregateZero>(V)) continue; - ConstantStruct *CS = cast<ConstantStruct>(*i); + ConstantStruct *CS = cast<ConstantStruct>(V); if (isa<ConstantPointerNull>(CS->getOperand(1))) continue; diff --git a/contrib/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp b/contrib/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp new file mode 100644 index 0000000..8c23865 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Utils/EscapeEnumerator.cpp @@ -0,0 +1,96 @@ +//===- EscapeEnumerator.cpp -----------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Defines a helper class that enumerates all possible exits from a function, +// including exception handling. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/EscapeEnumerator.h" +#include "llvm/Analysis/EHPersonalities.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Module.h" +#include "llvm/Transforms/Utils/Local.h" +using namespace llvm; + +static Constant *getDefaultPersonalityFn(Module *M) { + LLVMContext &C = M->getContext(); + Triple T(M->getTargetTriple()); + EHPersonality Pers = getDefaultEHPersonality(T); + return M->getOrInsertFunction(getEHPersonalityName(Pers), + FunctionType::get(Type::getInt32Ty(C), true)); +} + +IRBuilder<> *EscapeEnumerator::Next() { + if (Done) + return nullptr; + + // Find all 'return', 'resume', and 'unwind' instructions. + while (StateBB != StateE) { + BasicBlock *CurBB = &*StateBB++; + + // Branches and invokes do not escape, only unwind, resume, and return + // do. 
+ TerminatorInst *TI = CurBB->getTerminator(); + if (!isa<ReturnInst>(TI) && !isa<ResumeInst>(TI)) + continue; + + Builder.SetInsertPoint(TI); + return &Builder; + } + + Done = true; + + if (!HandleExceptions) + return nullptr; + + if (F.doesNotThrow()) + return nullptr; + + // Find all 'call' instructions that may throw. + SmallVector<Instruction *, 16> Calls; + for (BasicBlock &BB : F) + for (Instruction &II : BB) + if (CallInst *CI = dyn_cast<CallInst>(&II)) + if (!CI->doesNotThrow()) + Calls.push_back(CI); + + if (Calls.empty()) + return nullptr; + + // Create a cleanup block. + LLVMContext &C = F.getContext(); + BasicBlock *CleanupBB = BasicBlock::Create(C, CleanupBBName, &F); + Type *ExnTy = + StructType::get(Type::getInt8PtrTy(C), Type::getInt32Ty(C), nullptr); + if (!F.hasPersonalityFn()) { + Constant *PersFn = getDefaultPersonalityFn(F.getParent()); + F.setPersonalityFn(PersFn); + } + + if (isFuncletEHPersonality(classifyEHPersonality(F.getPersonalityFn()))) { + report_fatal_error("Funclet EH not supported"); + } + + LandingPadInst *LPad = + LandingPadInst::Create(ExnTy, 1, "cleanup.lpad", CleanupBB); + LPad->setCleanup(true); + ResumeInst *RI = ResumeInst::Create(LPad, CleanupBB); + + // Transform the 'call' instructions into 'invoke's branching to the + // cleanup block. Go in reverse order to make prettier BB names. + SmallVector<Value *, 16> Args; + for (unsigned I = Calls.size(); I != 0;) { + CallInst *CI = cast<CallInst>(Calls[--I]); + changeToInvokeAndSplitBasicBlock(CI, CleanupBB); + } + + Builder.SetInsertPoint(RI); + return &Builder; +} diff --git a/contrib/llvm/lib/Transforms/Utils/Evaluator.cpp b/contrib/llvm/lib/Transforms/Utils/Evaluator.cpp index cd130ab..4adf175 100644 --- a/contrib/llvm/lib/Transforms/Utils/Evaluator.cpp +++ b/contrib/llvm/lib/Transforms/Utils/Evaluator.cpp @@ -203,9 +203,9 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, return false; // no volatile/atomic accesses. } Constant *Ptr = getVal(SI->getOperand(1)); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) { + if (auto *FoldedPtr = ConstantFoldConstant(Ptr, DL, TLI)) { DEBUG(dbgs() << "Folding constant ptr expression: " << *Ptr); - Ptr = ConstantFoldConstantExpression(CE, DL, TLI); + Ptr = FoldedPtr; DEBUG(dbgs() << "; To: " << *Ptr << "\n"); } if (!isSimpleEnoughPointerToCommit(Ptr)) { @@ -249,8 +249,8 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, Constant * const IdxList[] = {IdxZero, IdxZero}; Ptr = ConstantExpr::getGetElementPtr(nullptr, Ptr, IdxList); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) - Ptr = ConstantFoldConstantExpression(CE, DL, TLI); + if (auto *FoldedPtr = ConstantFoldConstant(Ptr, DL, TLI)) + Ptr = FoldedPtr; // If we can't improve the situation by introspecting NewTy, // we have to give up. 
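A note on the Evaluator hunks just above: they switch from folding only ConstantExprs to a try-fold-anything idiom. Below is a minimal sketch of that idiom, under the assumption (matching the calls in these hunks) that ConstantFoldConstant returns null when nothing simplifies; the helper name is illustrative only.

#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/DataLayout.h"
using namespace llvm;

// Illustrative helper (name invented): try to fold an arbitrary constant and
// keep the original when ConstantFoldConstant finds nothing simpler, which is
// how the hunks above use its return value.
static Constant *foldOrKeep(Constant *C, const DataLayout &DL,
                            const TargetLibraryInfo *TLI) {
  if (Constant *Folded = ConstantFoldConstant(C, DL, TLI))
    return Folded; // a simplified constant was produced
  return C;        // no simplification possible; keep the original
}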
@@ -324,8 +324,8 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, } Constant *Ptr = getVal(LI->getOperand(0)); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) { - Ptr = ConstantFoldConstantExpression(CE, DL, TLI); + if (auto *FoldedPtr = ConstantFoldConstant(Ptr, DL, TLI)) { + Ptr = FoldedPtr; DEBUG(dbgs() << "Found a constant pointer expression, constant " "folding: " << *Ptr << "\n"); } @@ -512,8 +512,8 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, } if (!CurInst->use_empty()) { - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(InstResult)) - InstResult = ConstantFoldConstantExpression(CE, DL, TLI); + if (auto *FoldedInstResult = ConstantFoldConstant(InstResult, DL, TLI)) + InstResult = FoldedInstResult; setVal(&*CurInst, InstResult); } @@ -537,7 +537,7 @@ bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal, const SmallVectorImpl<Constant*> &ActualArgs) { // Check to see if this function is already executing (recursion). If so, // bail out. TODO: we might want to accept limited recursion. - if (std::find(CallStack.begin(), CallStack.end(), F) != CallStack.end()) + if (is_contained(CallStack, F)) return false; CallStack.push_back(F); diff --git a/contrib/llvm/lib/Transforms/Utils/FlattenCFG.cpp b/contrib/llvm/lib/Transforms/Utils/FlattenCFG.cpp index 492ae9f..7b96fbb 100644 --- a/contrib/llvm/lib/Transforms/Utils/FlattenCFG.cpp +++ b/contrib/llvm/lib/Transforms/Utils/FlattenCFG.cpp @@ -463,19 +463,14 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) { } bool FlattenCFGOpt::run(BasicBlock *BB) { - bool Changed = false; assert(BB && BB->getParent() && "Block not embedded in function!"); assert(BB->getTerminator() && "Degenerate basic block encountered!"); IRBuilder<> Builder(BB); - if (FlattenParallelAndOr(BB, Builder)) + if (FlattenParallelAndOr(BB, Builder) || MergeIfRegion(BB, Builder)) return true; - - if (MergeIfRegion(BB, Builder)) - return true; - - return Changed; + return false; } /// FlattenCFG - This function is used to flatten a CFG. For diff --git a/contrib/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/contrib/llvm/lib/Transforms/Utils/FunctionComparator.cpp new file mode 100644 index 0000000..81a7c4c --- /dev/null +++ b/contrib/llvm/lib/Transforms/Utils/FunctionComparator.cpp @@ -0,0 +1,919 @@ +//===- FunctionComparator.h - Function Comparator -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the FunctionComparator and GlobalNumberState classes +// which are used by the MergeFunctions pass for comparing functions. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/FunctionComparator.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "functioncomparator" + +int FunctionComparator::cmpNumbers(uint64_t L, uint64_t R) const { + if (L < R) return -1; + if (L > R) return 1; + return 0; +} + +int FunctionComparator::cmpOrderings(AtomicOrdering L, AtomicOrdering R) const { + if ((int)L < (int)R) return -1; + if ((int)L > (int)R) return 1; + return 0; +} + +int FunctionComparator::cmpAPInts(const APInt &L, const APInt &R) const { + if (int Res = cmpNumbers(L.getBitWidth(), R.getBitWidth())) + return Res; + if (L.ugt(R)) return 1; + if (R.ugt(L)) return -1; + return 0; +} + +int FunctionComparator::cmpAPFloats(const APFloat &L, const APFloat &R) const { + // Floats are ordered first by semantics (i.e. float, double, half, etc.), + // then by value interpreted as a bitstring (aka APInt). + const fltSemantics &SL = L.getSemantics(), &SR = R.getSemantics(); + if (int Res = cmpNumbers(APFloat::semanticsPrecision(SL), + APFloat::semanticsPrecision(SR))) + return Res; + if (int Res = cmpNumbers(APFloat::semanticsMaxExponent(SL), + APFloat::semanticsMaxExponent(SR))) + return Res; + if (int Res = cmpNumbers(APFloat::semanticsMinExponent(SL), + APFloat::semanticsMinExponent(SR))) + return Res; + if (int Res = cmpNumbers(APFloat::semanticsSizeInBits(SL), + APFloat::semanticsSizeInBits(SR))) + return Res; + return cmpAPInts(L.bitcastToAPInt(), R.bitcastToAPInt()); +} + +int FunctionComparator::cmpMem(StringRef L, StringRef R) const { + // Prevent heavy comparison, compare sizes first. + if (int Res = cmpNumbers(L.size(), R.size())) + return Res; + + // Compare strings lexicographically only when it is necessary: only when + // strings are equal in size. + return L.compare(R); +} + +int FunctionComparator::cmpAttrs(const AttributeSet L, + const AttributeSet R) const { + if (int Res = cmpNumbers(L.getNumSlots(), R.getNumSlots())) + return Res; + + for (unsigned i = 0, e = L.getNumSlots(); i != e; ++i) { + AttributeSet::iterator LI = L.begin(i), LE = L.end(i), RI = R.begin(i), + RE = R.end(i); + for (; LI != LE && RI != RE; ++LI, ++RI) { + Attribute LA = *LI; + Attribute RA = *RI; + if (LA < RA) + return -1; + if (RA < LA) + return 1; + } + if (LI != LE) + return 1; + if (RI != RE) + return -1; + } + return 0; +} + +int FunctionComparator::cmpRangeMetadata(const MDNode *L, + const MDNode *R) const { + if (L == R) + return 0; + if (!L) + return -1; + if (!R) + return 1; + // Range metadata is a sequence of numbers. Make sure they are the same + // sequence. + // TODO: Note that as this is metadata, it is possible to drop and/or merge + // this data when considering functions to merge. Thus this comparison would + // return 0 (i.e. equivalent), but merging would become more complicated + // because the ranges would need to be unioned. It is not likely that + // functions differ ONLY in this metadata if they are actually the same + // function semantically. 
+ if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands())) + return Res; + for (size_t I = 0; I < L->getNumOperands(); ++I) { + ConstantInt *LLow = mdconst::extract<ConstantInt>(L->getOperand(I)); + ConstantInt *RLow = mdconst::extract<ConstantInt>(R->getOperand(I)); + if (int Res = cmpAPInts(LLow->getValue(), RLow->getValue())) + return Res; + } + return 0; +} + +int FunctionComparator::cmpOperandBundlesSchema(const Instruction *L, + const Instruction *R) const { + ImmutableCallSite LCS(L); + ImmutableCallSite RCS(R); + + assert(LCS && RCS && "Must be calls or invokes!"); + assert(LCS.isCall() == RCS.isCall() && "Can't compare otherwise!"); + + if (int Res = + cmpNumbers(LCS.getNumOperandBundles(), RCS.getNumOperandBundles())) + return Res; + + for (unsigned i = 0, e = LCS.getNumOperandBundles(); i != e; ++i) { + auto OBL = LCS.getOperandBundleAt(i); + auto OBR = RCS.getOperandBundleAt(i); + + if (int Res = OBL.getTagName().compare(OBR.getTagName())) + return Res; + + if (int Res = cmpNumbers(OBL.Inputs.size(), OBR.Inputs.size())) + return Res; + } + + return 0; +} + +/// Constants comparison: +/// 1. Check whether type of L constant could be losslessly bitcasted to R +/// type. +/// 2. Compare constant contents. +/// For more details see declaration comments. +int FunctionComparator::cmpConstants(const Constant *L, + const Constant *R) const { + + Type *TyL = L->getType(); + Type *TyR = R->getType(); + + // Check whether types are bitcastable. This part is just re-factored + // Type::canLosslesslyBitCastTo method, but instead of returning true/false, + // we also pack into result which type is "less" for us. + int TypesRes = cmpTypes(TyL, TyR); + if (TypesRes != 0) { + // Types are different, but check whether we can bitcast them. + if (!TyL->isFirstClassType()) { + if (TyR->isFirstClassType()) + return -1; + // Neither TyL nor TyR are values of first class type. Return the result + // of comparing the types + return TypesRes; + } + if (!TyR->isFirstClassType()) { + if (TyL->isFirstClassType()) + return 1; + return TypesRes; + } + + // Vector -> Vector conversions are always lossless if the two vector types + // have the same size, otherwise not. + unsigned TyLWidth = 0; + unsigned TyRWidth = 0; + + if (auto *VecTyL = dyn_cast<VectorType>(TyL)) + TyLWidth = VecTyL->getBitWidth(); + if (auto *VecTyR = dyn_cast<VectorType>(TyR)) + TyRWidth = VecTyR->getBitWidth(); + + if (TyLWidth != TyRWidth) + return cmpNumbers(TyLWidth, TyRWidth); + + // Zero bit-width means neither TyL nor TyR are vectors. + if (!TyLWidth) { + PointerType *PTyL = dyn_cast<PointerType>(TyL); + PointerType *PTyR = dyn_cast<PointerType>(TyR); + if (PTyL && PTyR) { + unsigned AddrSpaceL = PTyL->getAddressSpace(); + unsigned AddrSpaceR = PTyR->getAddressSpace(); + if (int Res = cmpNumbers(AddrSpaceL, AddrSpaceR)) + return Res; + } + if (PTyL) + return 1; + if (PTyR) + return -1; + + // TyL and TyR aren't vectors, nor pointers. We don't know how to + // bitcast them. + return TypesRes; + } + } + + // OK, types are bitcastable, now check constant contents. 
+ + if (L->isNullValue() && R->isNullValue()) + return TypesRes; + if (L->isNullValue() && !R->isNullValue()) + return 1; + if (!L->isNullValue() && R->isNullValue()) + return -1; + + auto GlobalValueL = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(L)); + auto GlobalValueR = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(R)); + if (GlobalValueL && GlobalValueR) { + return cmpGlobalValues(GlobalValueL, GlobalValueR); + } + + if (int Res = cmpNumbers(L->getValueID(), R->getValueID())) + return Res; + + if (const auto *SeqL = dyn_cast<ConstantDataSequential>(L)) { + const auto *SeqR = cast<ConstantDataSequential>(R); + // This handles ConstantDataArray and ConstantDataVector. Note that we + // compare the two raw data arrays, which might differ depending on the host + // endianness. This isn't a problem though, because the endianness of a module + // will affect the order of the constants, but this order is the same + // for a given input module and host platform. + return cmpMem(SeqL->getRawDataValues(), SeqR->getRawDataValues()); + } + + switch (L->getValueID()) { + case Value::UndefValueVal: + case Value::ConstantTokenNoneVal: + return TypesRes; + case Value::ConstantIntVal: { + const APInt &LInt = cast<ConstantInt>(L)->getValue(); + const APInt &RInt = cast<ConstantInt>(R)->getValue(); + return cmpAPInts(LInt, RInt); + } + case Value::ConstantFPVal: { + const APFloat &LAPF = cast<ConstantFP>(L)->getValueAPF(); + const APFloat &RAPF = cast<ConstantFP>(R)->getValueAPF(); + return cmpAPFloats(LAPF, RAPF); + } + case Value::ConstantArrayVal: { + const ConstantArray *LA = cast<ConstantArray>(L); + const ConstantArray *RA = cast<ConstantArray>(R); + uint64_t NumElementsL = cast<ArrayType>(TyL)->getNumElements(); + uint64_t NumElementsR = cast<ArrayType>(TyR)->getNumElements(); + if (int Res = cmpNumbers(NumElementsL, NumElementsR)) + return Res; + for (uint64_t i = 0; i < NumElementsL; ++i) { + if (int Res = cmpConstants(cast<Constant>(LA->getOperand(i)), + cast<Constant>(RA->getOperand(i)))) + return Res; + } + return 0; + } + case Value::ConstantStructVal: { + const ConstantStruct *LS = cast<ConstantStruct>(L); + const ConstantStruct *RS = cast<ConstantStruct>(R); + unsigned NumElementsL = cast<StructType>(TyL)->getNumElements(); + unsigned NumElementsR = cast<StructType>(TyR)->getNumElements(); + if (int Res = cmpNumbers(NumElementsL, NumElementsR)) + return Res; + for (unsigned i = 0; i != NumElementsL; ++i) { + if (int Res = cmpConstants(cast<Constant>(LS->getOperand(i)), + cast<Constant>(RS->getOperand(i)))) + return Res; + } + return 0; + } + case Value::ConstantVectorVal: { + const ConstantVector *LV = cast<ConstantVector>(L); + const ConstantVector *RV = cast<ConstantVector>(R); + unsigned NumElementsL = cast<VectorType>(TyL)->getNumElements(); + unsigned NumElementsR = cast<VectorType>(TyR)->getNumElements(); + if (int Res = cmpNumbers(NumElementsL, NumElementsR)) + return Res; + for (uint64_t i = 0; i < NumElementsL; ++i) { + if (int Res = cmpConstants(cast<Constant>(LV->getOperand(i)), + cast<Constant>(RV->getOperand(i)))) + return Res; + } + return 0; + } + case Value::ConstantExprVal: { + const ConstantExpr *LE = cast<ConstantExpr>(L); + const ConstantExpr *RE = cast<ConstantExpr>(R); + unsigned NumOperandsL = LE->getNumOperands(); + unsigned NumOperandsR = RE->getNumOperands(); + if (int Res = cmpNumbers(NumOperandsL, NumOperandsR)) + return Res; + for (unsigned i = 0; i < NumOperandsL; ++i) { + if (int Res = cmpConstants(cast<Constant>(LE->getOperand(i)), +
cast<Constant>(RE->getOperand(i)))) + return Res; + } + return 0; + } + case Value::BlockAddressVal: { + const BlockAddress *LBA = cast<BlockAddress>(L); + const BlockAddress *RBA = cast<BlockAddress>(R); + if (int Res = cmpValues(LBA->getFunction(), RBA->getFunction())) + return Res; + if (LBA->getFunction() == RBA->getFunction()) { + // They are BBs in the same function. Order by which comes first in the + // BB order of the function. This order is deterministic. + Function* F = LBA->getFunction(); + BasicBlock *LBB = LBA->getBasicBlock(); + BasicBlock *RBB = RBA->getBasicBlock(); + if (LBB == RBB) + return 0; + for(BasicBlock &BB : F->getBasicBlockList()) { + if (&BB == LBB) { + assert(&BB != RBB); + return -1; + } + if (&BB == RBB) + return 1; + } + llvm_unreachable("Basic Block Address does not point to a basic block in " + "its function."); + return -1; + } else { + // cmpValues said the functions are the same. So because they aren't + // literally the same pointer, they must respectively be the left and + // right functions. + assert(LBA->getFunction() == FnL && RBA->getFunction() == FnR); + // cmpValues will tell us if these are equivalent BasicBlocks, in the + // context of their respective functions. + return cmpValues(LBA->getBasicBlock(), RBA->getBasicBlock()); + } + } + default: // Unknown constant, abort. + DEBUG(dbgs() << "Looking at valueID " << L->getValueID() << "\n"); + llvm_unreachable("Constant ValueID not recognized."); + return -1; + } +} + +int FunctionComparator::cmpGlobalValues(GlobalValue *L, GlobalValue *R) const { + uint64_t LNumber = GlobalNumbers->getNumber(L); + uint64_t RNumber = GlobalNumbers->getNumber(R); + return cmpNumbers(LNumber, RNumber); +} + +/// cmpType - compares two types, +/// defines total ordering among the types set. +/// See method declaration comments for more details. +int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { + PointerType *PTyL = dyn_cast<PointerType>(TyL); + PointerType *PTyR = dyn_cast<PointerType>(TyR); + + const DataLayout &DL = FnL->getParent()->getDataLayout(); + if (PTyL && PTyL->getAddressSpace() == 0) + TyL = DL.getIntPtrType(TyL); + if (PTyR && PTyR->getAddressSpace() == 0) + TyR = DL.getIntPtrType(TyR); + + if (TyL == TyR) + return 0; + + if (int Res = cmpNumbers(TyL->getTypeID(), TyR->getTypeID())) + return Res; + + switch (TyL->getTypeID()) { + default: + llvm_unreachable("Unknown type!"); + // Fall through in Release mode. + LLVM_FALLTHROUGH; + case Type::IntegerTyID: + return cmpNumbers(cast<IntegerType>(TyL)->getBitWidth(), + cast<IntegerType>(TyR)->getBitWidth()); + // TyL == TyR would have returned true earlier, because types are uniqued. 
+ case Type::VoidTyID: + case Type::FloatTyID: + case Type::DoubleTyID: + case Type::X86_FP80TyID: + case Type::FP128TyID: + case Type::PPC_FP128TyID: + case Type::LabelTyID: + case Type::MetadataTyID: + case Type::TokenTyID: + return 0; + + case Type::PointerTyID: { + assert(PTyL && PTyR && "Both types must be pointers here."); + return cmpNumbers(PTyL->getAddressSpace(), PTyR->getAddressSpace()); + } + + case Type::StructTyID: { + StructType *STyL = cast<StructType>(TyL); + StructType *STyR = cast<StructType>(TyR); + if (STyL->getNumElements() != STyR->getNumElements()) + return cmpNumbers(STyL->getNumElements(), STyR->getNumElements()); + + if (STyL->isPacked() != STyR->isPacked()) + return cmpNumbers(STyL->isPacked(), STyR->isPacked()); + + for (unsigned i = 0, e = STyL->getNumElements(); i != e; ++i) { + if (int Res = cmpTypes(STyL->getElementType(i), STyR->getElementType(i))) + return Res; + } + return 0; + } + + case Type::FunctionTyID: { + FunctionType *FTyL = cast<FunctionType>(TyL); + FunctionType *FTyR = cast<FunctionType>(TyR); + if (FTyL->getNumParams() != FTyR->getNumParams()) + return cmpNumbers(FTyL->getNumParams(), FTyR->getNumParams()); + + if (FTyL->isVarArg() != FTyR->isVarArg()) + return cmpNumbers(FTyL->isVarArg(), FTyR->isVarArg()); + + if (int Res = cmpTypes(FTyL->getReturnType(), FTyR->getReturnType())) + return Res; + + for (unsigned i = 0, e = FTyL->getNumParams(); i != e; ++i) { + if (int Res = cmpTypes(FTyL->getParamType(i), FTyR->getParamType(i))) + return Res; + } + return 0; + } + + case Type::ArrayTyID: + case Type::VectorTyID: { + auto *STyL = cast<SequentialType>(TyL); + auto *STyR = cast<SequentialType>(TyR); + if (STyL->getNumElements() != STyR->getNumElements()) + return cmpNumbers(STyL->getNumElements(), STyR->getNumElements()); + return cmpTypes(STyL->getElementType(), STyR->getElementType()); + } + } +} + +// Determine whether the two operations are the same except that pointer-to-A +// and pointer-to-B are equivalent. This should be kept in sync with +// Instruction::isSameOperationAs. +// Read method declaration comments for more details. +int FunctionComparator::cmpOperations(const Instruction *L, + const Instruction *R, + bool &needToCmpOperands) const { + needToCmpOperands = true; + if (int Res = cmpValues(L, R)) + return Res; + + // Differences from Instruction::isSameOperationAs: + // * replace type comparison with calls to cmpTypes. + // * we test for I->getRawSubclassOptionalData (nuw/nsw/tail) at the top. + // * because of the above, we don't test for the tail bit on calls later on. + if (int Res = cmpNumbers(L->getOpcode(), R->getOpcode())) + return Res; + + if (const GetElementPtrInst *GEPL = dyn_cast<GetElementPtrInst>(L)) { + needToCmpOperands = false; + const GetElementPtrInst *GEPR = cast<GetElementPtrInst>(R); + if (int Res = + cmpValues(GEPL->getPointerOperand(), GEPR->getPointerOperand())) + return Res; + return cmpGEPs(GEPL, GEPR); + } + + if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands())) + return Res; + + if (int Res = cmpTypes(L->getType(), R->getType())) + return Res; + + if (int Res = cmpNumbers(L->getRawSubclassOptionalData(), + R->getRawSubclassOptionalData())) + return Res; + + // We have two instructions of identical opcode and #operands. 
Check to see + // if all operands are the same type + for (unsigned i = 0, e = L->getNumOperands(); i != e; ++i) { + if (int Res = + cmpTypes(L->getOperand(i)->getType(), R->getOperand(i)->getType())) + return Res; + } + + // Check special state that is a part of some instructions. + if (const AllocaInst *AI = dyn_cast<AllocaInst>(L)) { + if (int Res = cmpTypes(AI->getAllocatedType(), + cast<AllocaInst>(R)->getAllocatedType())) + return Res; + return cmpNumbers(AI->getAlignment(), cast<AllocaInst>(R)->getAlignment()); + } + if (const LoadInst *LI = dyn_cast<LoadInst>(L)) { + if (int Res = cmpNumbers(LI->isVolatile(), cast<LoadInst>(R)->isVolatile())) + return Res; + if (int Res = + cmpNumbers(LI->getAlignment(), cast<LoadInst>(R)->getAlignment())) + return Res; + if (int Res = + cmpOrderings(LI->getOrdering(), cast<LoadInst>(R)->getOrdering())) + return Res; + if (int Res = + cmpNumbers(LI->getSynchScope(), cast<LoadInst>(R)->getSynchScope())) + return Res; + return cmpRangeMetadata(LI->getMetadata(LLVMContext::MD_range), + cast<LoadInst>(R)->getMetadata(LLVMContext::MD_range)); + } + if (const StoreInst *SI = dyn_cast<StoreInst>(L)) { + if (int Res = + cmpNumbers(SI->isVolatile(), cast<StoreInst>(R)->isVolatile())) + return Res; + if (int Res = + cmpNumbers(SI->getAlignment(), cast<StoreInst>(R)->getAlignment())) + return Res; + if (int Res = + cmpOrderings(SI->getOrdering(), cast<StoreInst>(R)->getOrdering())) + return Res; + return cmpNumbers(SI->getSynchScope(), cast<StoreInst>(R)->getSynchScope()); + } + if (const CmpInst *CI = dyn_cast<CmpInst>(L)) + return cmpNumbers(CI->getPredicate(), cast<CmpInst>(R)->getPredicate()); + if (const CallInst *CI = dyn_cast<CallInst>(L)) { + if (int Res = cmpNumbers(CI->getCallingConv(), + cast<CallInst>(R)->getCallingConv())) + return Res; + if (int Res = + cmpAttrs(CI->getAttributes(), cast<CallInst>(R)->getAttributes())) + return Res; + if (int Res = cmpOperandBundlesSchema(CI, R)) + return Res; + return cmpRangeMetadata( + CI->getMetadata(LLVMContext::MD_range), + cast<CallInst>(R)->getMetadata(LLVMContext::MD_range)); + } + if (const InvokeInst *II = dyn_cast<InvokeInst>(L)) { + if (int Res = cmpNumbers(II->getCallingConv(), + cast<InvokeInst>(R)->getCallingConv())) + return Res; + if (int Res = + cmpAttrs(II->getAttributes(), cast<InvokeInst>(R)->getAttributes())) + return Res; + if (int Res = cmpOperandBundlesSchema(II, R)) + return Res; + return cmpRangeMetadata( + II->getMetadata(LLVMContext::MD_range), + cast<InvokeInst>(R)->getMetadata(LLVMContext::MD_range)); + } + if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) { + ArrayRef<unsigned> LIndices = IVI->getIndices(); + ArrayRef<unsigned> RIndices = cast<InsertValueInst>(R)->getIndices(); + if (int Res = cmpNumbers(LIndices.size(), RIndices.size())) + return Res; + for (size_t i = 0, e = LIndices.size(); i != e; ++i) { + if (int Res = cmpNumbers(LIndices[i], RIndices[i])) + return Res; + } + return 0; + } + if (const ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(L)) { + ArrayRef<unsigned> LIndices = EVI->getIndices(); + ArrayRef<unsigned> RIndices = cast<ExtractValueInst>(R)->getIndices(); + if (int Res = cmpNumbers(LIndices.size(), RIndices.size())) + return Res; + for (size_t i = 0, e = LIndices.size(); i != e; ++i) { + if (int Res = cmpNumbers(LIndices[i], RIndices[i])) + return Res; + } + } + if (const FenceInst *FI = dyn_cast<FenceInst>(L)) { + if (int Res = + cmpOrderings(FI->getOrdering(), cast<FenceInst>(R)->getOrdering())) + return Res; + return 
cmpNumbers(FI->getSynchScope(), cast<FenceInst>(R)->getSynchScope()); + } + if (const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(L)) { + if (int Res = cmpNumbers(CXI->isVolatile(), + cast<AtomicCmpXchgInst>(R)->isVolatile())) + return Res; + if (int Res = cmpNumbers(CXI->isWeak(), + cast<AtomicCmpXchgInst>(R)->isWeak())) + return Res; + if (int Res = + cmpOrderings(CXI->getSuccessOrdering(), + cast<AtomicCmpXchgInst>(R)->getSuccessOrdering())) + return Res; + if (int Res = + cmpOrderings(CXI->getFailureOrdering(), + cast<AtomicCmpXchgInst>(R)->getFailureOrdering())) + return Res; + return cmpNumbers(CXI->getSynchScope(), + cast<AtomicCmpXchgInst>(R)->getSynchScope()); + } + if (const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(L)) { + if (int Res = cmpNumbers(RMWI->getOperation(), + cast<AtomicRMWInst>(R)->getOperation())) + return Res; + if (int Res = cmpNumbers(RMWI->isVolatile(), + cast<AtomicRMWInst>(R)->isVolatile())) + return Res; + if (int Res = cmpOrderings(RMWI->getOrdering(), + cast<AtomicRMWInst>(R)->getOrdering())) + return Res; + return cmpNumbers(RMWI->getSynchScope(), + cast<AtomicRMWInst>(R)->getSynchScope()); + } + if (const PHINode *PNL = dyn_cast<PHINode>(L)) { + const PHINode *PNR = cast<PHINode>(R); + // Ensure that in addition to the incoming values being identical + // (checked by the caller of this function), the incoming blocks + // are also identical. + for (unsigned i = 0, e = PNL->getNumIncomingValues(); i != e; ++i) { + if (int Res = + cmpValues(PNL->getIncomingBlock(i), PNR->getIncomingBlock(i))) + return Res; + } + } + return 0; +} + +// Determine whether two GEP operations perform the same underlying arithmetic. +// Read method declaration comments for more details. +int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, + const GEPOperator *GEPR) const { + + unsigned int ASL = GEPL->getPointerAddressSpace(); + unsigned int ASR = GEPR->getPointerAddressSpace(); + + if (int Res = cmpNumbers(ASL, ASR)) + return Res; + + // When we have target data, we can reduce the GEP down to the value in bytes + // added to the address. + const DataLayout &DL = FnL->getParent()->getDataLayout(); + unsigned BitWidth = DL.getPointerSizeInBits(ASL); + APInt OffsetL(BitWidth, 0), OffsetR(BitWidth, 0); + if (GEPL->accumulateConstantOffset(DL, OffsetL) && + GEPR->accumulateConstantOffset(DL, OffsetR)) + return cmpAPInts(OffsetL, OffsetR); + if (int Res = cmpTypes(GEPL->getSourceElementType(), + GEPR->getSourceElementType())) + return Res; + + if (int Res = cmpNumbers(GEPL->getNumOperands(), GEPR->getNumOperands())) + return Res; + + for (unsigned i = 0, e = GEPL->getNumOperands(); i != e; ++i) { + if (int Res = cmpValues(GEPL->getOperand(i), GEPR->getOperand(i))) + return Res; + } + + return 0; +} + +int FunctionComparator::cmpInlineAsm(const InlineAsm *L, + const InlineAsm *R) const { + // InlineAsm's are uniqued. If they are the same pointer, obviously they are + // the same, otherwise compare the fields. 
+ if (L == R) + return 0; + if (int Res = cmpTypes(L->getFunctionType(), R->getFunctionType())) + return Res; + if (int Res = cmpMem(L->getAsmString(), R->getAsmString())) + return Res; + if (int Res = cmpMem(L->getConstraintString(), R->getConstraintString())) + return Res; + if (int Res = cmpNumbers(L->hasSideEffects(), R->hasSideEffects())) + return Res; + if (int Res = cmpNumbers(L->isAlignStack(), R->isAlignStack())) + return Res; + if (int Res = cmpNumbers(L->getDialect(), R->getDialect())) + return Res; + llvm_unreachable("InlineAsm blocks were not uniqued."); + return 0; +} + +/// Compare two values used by the two functions under pair-wise comparison. If +/// this is the first time the values are seen, they're added to the mapping so +/// that we will detect mismatches on next use. +/// See comments in declaration for more details. +int FunctionComparator::cmpValues(const Value *L, const Value *R) const { + // Catch self-reference case. + if (L == FnL) { + if (R == FnR) + return 0; + return -1; + } + if (R == FnR) { + if (L == FnL) + return 0; + return 1; + } + + const Constant *ConstL = dyn_cast<Constant>(L); + const Constant *ConstR = dyn_cast<Constant>(R); + if (ConstL && ConstR) { + if (L == R) + return 0; + return cmpConstants(ConstL, ConstR); + } + + if (ConstL) + return 1; + if (ConstR) + return -1; + + const InlineAsm *InlineAsmL = dyn_cast<InlineAsm>(L); + const InlineAsm *InlineAsmR = dyn_cast<InlineAsm>(R); + + if (InlineAsmL && InlineAsmR) + return cmpInlineAsm(InlineAsmL, InlineAsmR); + if (InlineAsmL) + return 1; + if (InlineAsmR) + return -1; + + auto LeftSN = sn_mapL.insert(std::make_pair(L, sn_mapL.size())), + RightSN = sn_mapR.insert(std::make_pair(R, sn_mapR.size())); + + return cmpNumbers(LeftSN.first->second, RightSN.first->second); +} + +// Test whether two basic blocks have equivalent behaviour. +int FunctionComparator::cmpBasicBlocks(const BasicBlock *BBL, + const BasicBlock *BBR) const { + BasicBlock::const_iterator InstL = BBL->begin(), InstLE = BBL->end(); + BasicBlock::const_iterator InstR = BBR->begin(), InstRE = BBR->end(); + + do { + bool needToCmpOperands = true; + if (int Res = cmpOperations(&*InstL, &*InstR, needToCmpOperands)) + return Res; + if (needToCmpOperands) { + assert(InstL->getNumOperands() == InstR->getNumOperands()); + + for (unsigned i = 0, e = InstL->getNumOperands(); i != e; ++i) { + Value *OpL = InstL->getOperand(i); + Value *OpR = InstR->getOperand(i); + if (int Res = cmpValues(OpL, OpR)) + return Res; + // cmpValues should ensure this is true. + assert(cmpTypes(OpL->getType(), OpR->getType()) == 0); + } + } + + ++InstL; + ++InstR; + } while (InstL != InstLE && InstR != InstRE); + + if (InstL != InstLE && InstR == InstRE) + return 1; + if (InstL == InstLE && InstR != InstRE) + return -1; + return 0; +} + +int FunctionComparator::compareSignature() const { + if (int Res = cmpAttrs(FnL->getAttributes(), FnR->getAttributes())) + return Res; + + if (int Res = cmpNumbers(FnL->hasGC(), FnR->hasGC())) + return Res; + + if (FnL->hasGC()) { + if (int Res = cmpMem(FnL->getGC(), FnR->getGC())) + return Res; + } + + if (int Res = cmpNumbers(FnL->hasSection(), FnR->hasSection())) + return Res; + + if (FnL->hasSection()) { + if (int Res = cmpMem(FnL->getSection(), FnR->getSection())) + return Res; + } + + if (int Res = cmpNumbers(FnL->isVarArg(), FnR->isVarArg())) + return Res; + + // TODO: if it's internal and only used in direct calls, we could handle this + // case too. 
+ if (int Res = cmpNumbers(FnL->getCallingConv(), FnR->getCallingConv())) + return Res; + + if (int Res = cmpTypes(FnL->getFunctionType(), FnR->getFunctionType())) + return Res; + + assert(FnL->arg_size() == FnR->arg_size() && + "Identically typed functions have different numbers of args!"); + + // Visit the arguments so that they get enumerated in the order they're + // passed in. + for (Function::const_arg_iterator ArgLI = FnL->arg_begin(), + ArgRI = FnR->arg_begin(), + ArgLE = FnL->arg_end(); + ArgLI != ArgLE; ++ArgLI, ++ArgRI) { + if (cmpValues(&*ArgLI, &*ArgRI) != 0) + llvm_unreachable("Arguments repeat!"); + } + return 0; +} + +// Test whether the two functions have equivalent behaviour. +int FunctionComparator::compare() { + beginCompare(); + + if (int Res = compareSignature()) + return Res; + + // We do a CFG-ordered walk since the actual ordering of the blocks in the + // linked list is immaterial. Our walk starts at the entry block for both + // functions, then takes each block from each terminator in order. As an + // artifact, this also means that unreachable blocks are ignored. + SmallVector<const BasicBlock *, 8> FnLBBs, FnRBBs; + SmallPtrSet<const BasicBlock *, 32> VisitedBBs; // in terms of F1. + + FnLBBs.push_back(&FnL->getEntryBlock()); + FnRBBs.push_back(&FnR->getEntryBlock()); + + VisitedBBs.insert(FnLBBs[0]); + while (!FnLBBs.empty()) { + const BasicBlock *BBL = FnLBBs.pop_back_val(); + const BasicBlock *BBR = FnRBBs.pop_back_val(); + + if (int Res = cmpValues(BBL, BBR)) + return Res; + + if (int Res = cmpBasicBlocks(BBL, BBR)) + return Res; + + const TerminatorInst *TermL = BBL->getTerminator(); + const TerminatorInst *TermR = BBR->getTerminator(); + + assert(TermL->getNumSuccessors() == TermR->getNumSuccessors()); + for (unsigned i = 0, e = TermL->getNumSuccessors(); i != e; ++i) { + if (!VisitedBBs.insert(TermL->getSuccessor(i)).second) + continue; + + FnLBBs.push_back(TermL->getSuccessor(i)); + FnRBBs.push_back(TermR->getSuccessor(i)); + } + } + return 0; +} + +namespace { + +// Accumulate the hash of a sequence of 64-bit integers. This is similar to a +// hash of a sequence of 64bit ints, but the entire input does not need to be +// available at once. This interface is necessary for functionHash because it +// needs to accumulate the hash as the structure of the function is traversed +// without saving these values to an intermediate buffer. This form of hashing +// is not often needed, as usually the object to hash is just read from a +// buffer. +class HashAccumulator64 { + uint64_t Hash; +public: + // Initialize to random constant, so the state isn't zero. + HashAccumulator64() { Hash = 0x6acaa36bef8325c5ULL; } + void add(uint64_t V) { + Hash = llvm::hashing::detail::hash_16_bytes(Hash, V); + } + // No finishing is required, because the entire hash value is used. + uint64_t getHash() { return Hash; } +}; +} // end anonymous namespace + +// A function hash is calculated by considering only the number of arguments and +// whether a function is varargs, the order of basic blocks (given by the +// successors of each basic block in depth first order), and the order of +// opcodes of each instruction within each of these basic blocks. This mirrors +// the strategy compare() uses to compare functions by walking the BBs in depth +// first order and comparing each instruction in sequence. 
Because this hash +// does not look at the operands, it is insensitive to things such as the +// target of calls and the constants used in the function, which makes it useful +// when possibly merging functions which are the same modulo constants and call +// targets. +FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) { + HashAccumulator64 H; + H.add(F.isVarArg()); + H.add(F.arg_size()); + + SmallVector<const BasicBlock *, 8> BBs; + SmallSet<const BasicBlock *, 16> VisitedBBs; + + // Walk the blocks in the same order as FunctionComparator::cmpBasicBlocks(), + // accumulating the hash of the function "structure." (BB and opcode sequence) + BBs.push_back(&F.getEntryBlock()); + VisitedBBs.insert(BBs[0]); + while (!BBs.empty()) { + const BasicBlock *BB = BBs.pop_back_val(); + // This random value acts as a block header, as otherwise the partition of + // opcodes into BBs wouldn't affect the hash, only the order of the opcodes + H.add(45798); + for (auto &Inst : *BB) { + H.add(Inst.getOpcode()); + } + const TerminatorInst *Term = BB->getTerminator(); + for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i) { + if (!VisitedBBs.insert(Term->getSuccessor(i)).second) + continue; + BBs.push_back(Term->getSuccessor(i)); + } + } + return H.getHash(); +} + + diff --git a/contrib/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp b/contrib/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp index fcb25ba..9844190 100644 --- a/contrib/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/contrib/llvm/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -48,7 +48,7 @@ bool FunctionImportGlobalProcessing::doImportAsDefinition( GlobalsToImport); } -bool FunctionImportGlobalProcessing::doPromoteLocalToGlobal( +bool FunctionImportGlobalProcessing::shouldPromoteLocalToGlobal( const GlobalValue *SGV) { assert(SGV->hasLocalLinkage()); // Both the imported references and the original local variable must @@ -56,36 +56,57 @@ bool FunctionImportGlobalProcessing::doPromoteLocalToGlobal( if (!isPerformingImport() && !isModuleExporting()) return false; - // Local const variables never need to be promoted unless they are address - // taken. The imported uses can simply use the clone created in this module. - // For now we are conservative in determining which variables are not - // address taken by checking the unnamed addr flag. To be more aggressive, - // the address taken information must be checked earlier during parsing - // of the module and recorded in the summary index for use when importing - // from that module. - auto *GVar = dyn_cast<GlobalVariable>(SGV); - if (GVar && GVar->isConstant() && GVar->hasGlobalUnnamedAddr()) - return false; + if (isPerformingImport()) { + assert((!GlobalsToImport->count(SGV) || !isNonRenamableLocal(*SGV)) && + "Attempting to promote non-renamable local"); + // We don't know for sure yet if we are importing this value (as either + // a reference or a def), since we are simply walking all values in the + // module. But by necessity if we end up importing it and it is local, + // it must be promoted, so unconditionally promote all values in the + // importing module. + return true; + } - if (GVar && GVar->hasSection()) - // Some sections like "__DATA,__cfstring" are "magic" and promotion is not - // allowed. Just disable promotion on any GVar with sections right now. - return false; + // When exporting, consult the index. 
We can have more than one local + // with the same GUID, in the case of same-named locals in different but + // same-named source files that were compiled in their respective directories + // (so the source file name and resulting GUID is the same). Find the one + // in this module. + auto Summary = ImportIndex.findSummaryInModule( + SGV->getGUID(), SGV->getParent()->getModuleIdentifier()); + assert(Summary && "Missing summary for global value when exporting"); + auto Linkage = Summary->linkage(); + if (!GlobalValue::isLocalLinkage(Linkage)) { + assert(!isNonRenamableLocal(*SGV) && + "Attempting to promote non-renamable local"); + return true; + } - // Eventually we only need to promote functions in the exporting module that - // are referenced by a potentially exported function (i.e. one that is in the - // summary index). - return true; + return false; } -std::string FunctionImportGlobalProcessing::getName(const GlobalValue *SGV) { +#ifndef NDEBUG +bool FunctionImportGlobalProcessing::isNonRenamableLocal( + const GlobalValue &GV) const { + if (!GV.hasLocalLinkage()) + return false; + // This needs to stay in sync with the logic in buildModuleSummaryIndex. + if (GV.hasSection()) + return true; + if (Used.count(const_cast<GlobalValue *>(&GV))) + return true; + return false; +} +#endif + +std::string FunctionImportGlobalProcessing::getName(const GlobalValue *SGV, + bool DoPromote) { // For locals that must be promoted to global scope, ensure that // the promoted name uniquely identifies the copy in the original module, // using the ID assigned during combined index creation. When importing, // we rename all locals (not just those that are promoted) in order to // avoid naming conflicts between locals imported from different modules. - if (SGV->hasLocalLinkage() && - (doPromoteLocalToGlobal(SGV) || isPerformingImport())) + if (SGV->hasLocalLinkage() && (DoPromote || isPerformingImport())) return ModuleSummaryIndex::getGlobalNameForLocal( SGV->getName(), ImportIndex.getModuleHash(SGV->getParent()->getModuleIdentifier())); @@ -93,13 +114,14 @@ std::string FunctionImportGlobalProcessing::getName(const GlobalValue *SGV) { } GlobalValue::LinkageTypes -FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV) { +FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, + bool DoPromote) { // Any local variable that is referenced by an exported function needs // to be promoted to global scope. Since we don't currently know which // functions reference which local variables/functions, we must treat // all as potentially exported if this module is exporting anything. if (isModuleExporting()) { - if (SGV->hasLocalLinkage() && doPromoteLocalToGlobal(SGV)) + if (SGV->hasLocalLinkage() && DoPromote) return GlobalValue::ExternalLinkage; return SGV->getLinkage(); } @@ -164,7 +186,7 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV) { case GlobalValue::PrivateLinkage: // If we are promoting the local to global scope, it is handled // similarly to a normal externally visible global. 
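For reference while reading the linkage cases below: when a local does get promoted, the net effect of getName(), getLinkage() and processGlobalForThinLTO in this file is roughly the following. This is a condensed sketch, not code from the patch; GV and ImportIndex are assumed to be the GlobalValue being processed and the combined summary index.

// Condensed sketch of ThinLTO local promotion (assumes GlobalValue &GV and
// ModuleSummaryIndex &ImportIndex in scope, as in this file):
std::string PromotedName = ModuleSummaryIndex::getGlobalNameForLocal(
    GV.getName(),
    ImportIndex.getModuleHash(GV.getParent()->getModuleIdentifier()));
GV.setName(PromotedName);                        // unique across modules
GV.setLinkage(GlobalValue::ExternalLinkage);     // reachable from importing modules
GV.setVisibility(GlobalValue::HiddenVisibility); // but hidden in the final DSO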
- if (doPromoteLocalToGlobal(SGV)) { + if (DoPromote) { if (doImportAsDefinition(SGV) && !dyn_cast<GlobalAlias>(SGV)) return GlobalValue::AvailableExternallyLinkage; else @@ -190,14 +212,19 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV) { } void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { + bool DoPromote = false; if (GV.hasLocalLinkage() && - (doPromoteLocalToGlobal(&GV) || isPerformingImport())) { - GV.setName(getName(&GV)); - GV.setLinkage(getLinkage(&GV)); + ((DoPromote = shouldPromoteLocalToGlobal(&GV)) || isPerformingImport())) { + // Once we change the name or linkage it is difficult to determine + // again whether we should promote since shouldPromoteLocalToGlobal needs + // to locate the summary (based on GUID from name and linkage). Therefore, + // use DoPromote result saved above. + GV.setName(getName(&GV, DoPromote)); + GV.setLinkage(getLinkage(&GV, DoPromote)); if (!GV.hasLocalLinkage()) GV.setVisibility(GlobalValue::HiddenVisibility); } else - GV.setLinkage(getLinkage(&GV)); + GV.setLinkage(getLinkage(&GV, /* DoPromote */ false)); // Remove functions imported as available externally defs from comdats, // as this is a declaration for the linker, and will be dropped eventually. @@ -214,14 +241,6 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { } void FunctionImportGlobalProcessing::processGlobalsForThinLTO() { - if (!moduleCanBeRenamedForThinLTO(M)) { - // We would have blocked importing from this module by suppressing index - // generation. We still may be able to import into this module though. - assert(!isPerformingImport() && - "Should have blocked importing from module with local used in ASM"); - return; - } - for (GlobalVariable &GV : M.globals()) processGlobalForThinLTO(GV); for (Function &SF : M) diff --git a/contrib/llvm/lib/Transforms/Utils/GlobalStatus.cpp b/contrib/llvm/lib/Transforms/Utils/GlobalStatus.cpp index 266be41..74ebcda 100644 --- a/contrib/llvm/lib/Transforms/Utils/GlobalStatus.cpp +++ b/contrib/llvm/lib/Transforms/Utils/GlobalStatus.cpp @@ -20,9 +20,8 @@ using namespace llvm; /// and release, then return AcquireRelease. /// static AtomicOrdering strongerOrdering(AtomicOrdering X, AtomicOrdering Y) { - if (X == AtomicOrdering::Acquire && Y == AtomicOrdering::Release) - return AtomicOrdering::AcquireRelease; - if (Y == AtomicOrdering::Acquire && X == AtomicOrdering::Release) + if ((X == AtomicOrdering::Acquire && Y == AtomicOrdering::Release) || + (Y == AtomicOrdering::Acquire && X == AtomicOrdering::Release)) return AtomicOrdering::AcquireRelease; return (AtomicOrdering)std::max((unsigned)X, (unsigned)Y); } @@ -35,7 +34,7 @@ bool llvm::isSafeToDestroyConstant(const Constant *C) { if (isa<GlobalValue>(C)) return false; - if (isa<ConstantInt>(C) || isa<ConstantFP>(C)) + if (isa<ConstantData>(C)) return false; for (const User *U : C->users()) diff --git a/contrib/llvm/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp b/contrib/llvm/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp new file mode 100644 index 0000000..ed018bb --- /dev/null +++ b/contrib/llvm/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp @@ -0,0 +1,203 @@ +//===-- ImportedFunctionsInliningStats.cpp ----------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// Generating inliner statistics for imported functions, mostly useful for +// ThinLTO. +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/ImportedFunctionsInliningStatistics.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <iomanip> +#include <sstream> +using namespace llvm; + +ImportedFunctionsInliningStatistics::InlineGraphNode & +ImportedFunctionsInliningStatistics::createInlineGraphNode(const Function &F) { + + auto &ValueLookup = NodesMap[F.getName()]; + if (!ValueLookup) { + ValueLookup = llvm::make_unique<InlineGraphNode>(); + ValueLookup->Imported = F.getMetadata("thinlto_src_module") != nullptr; + } + return *ValueLookup; +} + +void ImportedFunctionsInliningStatistics::recordInline(const Function &Caller, + const Function &Callee) { + + InlineGraphNode &CallerNode = createInlineGraphNode(Caller); + InlineGraphNode &CalleeNode = createInlineGraphNode(Callee); + CalleeNode.NumberOfInlines++; + + if (!CallerNode.Imported && !CalleeNode.Imported) { + // Direct inline from not imported callee to not imported caller, so we + // don't have to add this to graph. It might be very helpful if you wanna + // get the inliner statistics in compile step where there are no imported + // functions. In this case the graph would be empty. + CalleeNode.NumberOfRealInlines++; + return; + } + + CallerNode.InlinedCallees.push_back(&CalleeNode); + if (!CallerNode.Imported) { + // We could avoid second lookup, but it would make the code ultra ugly. + auto It = NodesMap.find(Caller.getName()); + assert(It != NodesMap.end() && "The node should be already there."); + // Save Caller as a starting node for traversal. The string has to be one + // from map because Caller can disappear (and function name with it). 
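The statistics class is driven from the inliner rather than running on its own. A minimal usage sketch, using only the member functions defined in this file; the caller/callee pairs are assumed to come from whatever pass owns the object:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/ImportedFunctionsInliningStatistics.h"
using namespace llvm;

static void reportInlinerStats(
    Module &M, ArrayRef<std::pair<Function *, Function *>> Inlines) {
  ImportedFunctionsInliningStatistics Stats;
  Stats.setModuleInfo(M);               // once: count all / imported functions
  for (const auto &P : Inlines)         // once per successful inline
    Stats.recordInline(/*Caller=*/*P.first, /*Callee=*/*P.second);
  Stats.dump(/*Verbose=*/true);         // at the end: print the summary
}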
+ NonImportedCallers.push_back(It->first()); + } +} + +void ImportedFunctionsInliningStatistics::setModuleInfo(const Module &M) { + ModuleName = M.getName(); + for (const auto &F : M.functions()) { + AllFunctions++; + ImportedFunctions += int(F.getMetadata("thinlto_src_module") != nullptr); + } +} +static std::string getStatString(const char *Msg, int32_t Fraction, int32_t All, + const char *PercentageOfMsg, + bool LineEnd = true) { + double Result = 0; + if (All != 0) + Result = 100 * static_cast<double>(Fraction) / All; + + std::stringstream Str; + Str << std::setprecision(4) << Msg << ": " << Fraction << " [" << Result + << "% of " << PercentageOfMsg << "]"; + if (LineEnd) + Str << "\n"; + return Str.str(); +} + +void ImportedFunctionsInliningStatistics::dump(const bool Verbose) { + calculateRealInlines(); + NonImportedCallers.clear(); + + int32_t InlinedImportedFunctionsCount = 0; + int32_t InlinedNotImportedFunctionsCount = 0; + + int32_t InlinedImportedFunctionsToImportingModuleCount = 0; + int32_t InlinedNotImportedFunctionsToImportingModuleCount = 0; + + const auto SortedNodes = getSortedNodes(); + std::string Out; + Out.reserve(5000); + raw_string_ostream Ostream(Out); + + Ostream << "------- Dumping inliner stats for [" << ModuleName + << "] -------\n"; + + if (Verbose) + Ostream << "-- List of inlined functions:\n"; + + for (const auto &Node : SortedNodes) { + assert(Node->second->NumberOfInlines >= Node->second->NumberOfRealInlines); + if (Node->second->NumberOfInlines == 0) + continue; + + if (Node->second->Imported) { + InlinedImportedFunctionsCount++; + InlinedImportedFunctionsToImportingModuleCount += + int(Node->second->NumberOfRealInlines > 0); + } else { + InlinedNotImportedFunctionsCount++; + InlinedNotImportedFunctionsToImportingModuleCount += + int(Node->second->NumberOfRealInlines > 0); + } + + if (Verbose) + Ostream << "Inlined " + << (Node->second->Imported ? "imported " : "not imported ") + << "function [" << Node->first() << "]" + << ": #inlines = " << Node->second->NumberOfInlines + << ", #inlines_to_importing_module = " + << Node->second->NumberOfRealInlines << "\n"; + } + + auto InlinedFunctionsCount = + InlinedImportedFunctionsCount + InlinedNotImportedFunctionsCount; + auto NotImportedFuncCount = AllFunctions - ImportedFunctions; + auto ImportedNotInlinedIntoModule = + ImportedFunctions - InlinedImportedFunctionsToImportingModuleCount; + + Ostream << "-- Summary:\n" + << "All functions: " << AllFunctions + << ", imported functions: " << ImportedFunctions << "\n" + << getStatString("inlined functions", InlinedFunctionsCount, + AllFunctions, "all functions") + << getStatString("imported functions inlined anywhere", + InlinedImportedFunctionsCount, ImportedFunctions, + "imported functions") + << getStatString("imported functions inlined into importing module", + InlinedImportedFunctionsToImportingModuleCount, + ImportedFunctions, "imported functions", + /*LineEnd=*/false) + << getStatString(", remaining", ImportedNotInlinedIntoModule, + ImportedFunctions, "imported functions") + << getStatString("non-imported functions inlined anywhere", + InlinedNotImportedFunctionsCount, + NotImportedFuncCount, "non-imported functions") + << getStatString( + "non-imported functions inlined into importing module", + InlinedNotImportedFunctionsToImportingModuleCount, + NotImportedFuncCount, "non-imported functions"); + Ostream.flush(); + dbgs() << Out; +} + +void ImportedFunctionsInliningStatistics::calculateRealInlines() { + // Removing duplicated Callers. 
+ std::sort(NonImportedCallers.begin(), NonImportedCallers.end()); + NonImportedCallers.erase( + std::unique(NonImportedCallers.begin(), NonImportedCallers.end()), + NonImportedCallers.end()); + + for (const auto &Name : NonImportedCallers) { + auto &Node = *NodesMap[Name]; + if (!Node.Visited) + dfs(Node); + } +} + +void ImportedFunctionsInliningStatistics::dfs(InlineGraphNode &GraphNode) { + assert(!GraphNode.Visited); + GraphNode.Visited = true; + for (auto *const InlinedFunctionNode : GraphNode.InlinedCallees) { + InlinedFunctionNode->NumberOfRealInlines++; + if (!InlinedFunctionNode->Visited) + dfs(*InlinedFunctionNode); + } +} + +ImportedFunctionsInliningStatistics::SortedNodesTy +ImportedFunctionsInliningStatistics::getSortedNodes() { + SortedNodesTy SortedNodes; + SortedNodes.reserve(NodesMap.size()); + for (const NodesMapTy::value_type& Node : NodesMap) + SortedNodes.push_back(&Node); + + std::sort( + SortedNodes.begin(), SortedNodes.end(), + [&](const SortedNodesTy::value_type &Lhs, + const SortedNodesTy::value_type &Rhs) { + if (Lhs->second->NumberOfInlines != Rhs->second->NumberOfInlines) + return Lhs->second->NumberOfInlines > Rhs->second->NumberOfInlines; + if (Lhs->second->NumberOfRealInlines != Rhs->second->NumberOfRealInlines) + return Lhs->second->NumberOfRealInlines > + Rhs->second->NumberOfRealInlines; + return Lhs->first() < Rhs->first(); + }); + return SortedNodes; +} diff --git a/contrib/llvm/lib/Transforms/Utils/InlineFunction.cpp b/contrib/llvm/lib/Transforms/Utils/InlineFunction.cpp index e82c07f..a40079c 100644 --- a/contrib/llvm/lib/Transforms/Utils/InlineFunction.cpp +++ b/contrib/llvm/lib/Transforms/Utils/InlineFunction.cpp @@ -14,6 +14,7 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" @@ -228,7 +229,7 @@ static Value *getUnwindDestTokenHelper(Instruction *EHPad, Instruction *ChildPad = cast<Instruction>(Child); auto Memo = MemoMap.find(ChildPad); if (Memo == MemoMap.end()) { - // Haven't figure out this child pad yet; queue it. + // Haven't figured out this child pad yet; queue it. Worklist.push_back(ChildPad); continue; } @@ -366,6 +367,10 @@ static Value *getUnwindDestToken(Instruction *EHPad, // search up the chain to try to find a funclet with information. Put // null entries in the memo map to avoid re-processing as we go up. MemoMap[EHPad] = nullptr; +#ifndef NDEBUG + SmallPtrSet<Instruction *, 4> TempMemos; + TempMemos.insert(EHPad); +#endif Instruction *LastUselessPad = EHPad; Value *AncestorToken; for (AncestorToken = getParentPad(EHPad); @@ -374,6 +379,13 @@ static Value *getUnwindDestToken(Instruction *EHPad, // Skip over catchpads since they just follow their catchswitches. if (isa<CatchPadInst>(AncestorPad)) continue; + // If the MemoMap had an entry mapping AncestorPad to nullptr, since we + // haven't yet called getUnwindDestTokenHelper for AncestorPad in this + // call to getUnwindDestToken, that would mean that AncestorPad had no + // information in itself, its descendants, or its ancestors. If that + // were the case, then we should also have recorded the lack of information + // for the descendant that we're coming from. So assert that we don't + // find a null entry in the MemoMap for AncestorPad. 
assert(!MemoMap.count(AncestorPad) || MemoMap[AncestorPad]); auto AncestorMemo = MemoMap.find(AncestorPad); if (AncestorMemo == MemoMap.end()) { @@ -384,25 +396,85 @@ static Value *getUnwindDestToken(Instruction *EHPad, if (UnwindDestToken) break; LastUselessPad = AncestorPad; + MemoMap[LastUselessPad] = nullptr; +#ifndef NDEBUG + TempMemos.insert(LastUselessPad); +#endif } - // Since the whole tree under LastUselessPad has no information, it all must - // match UnwindDestToken; record that to avoid repeating the search. + // We know that getUnwindDestTokenHelper was called on LastUselessPad and + // returned nullptr (and likewise for EHPad and any of its ancestors up to + // LastUselessPad), so LastUselessPad has no information from below. Since + // getUnwindDestTokenHelper must investigate all downward paths through + // no-information nodes to prove that a node has no information like this, + // and since any time it finds information it records it in the MemoMap for + // not just the immediately-containing funclet but also any ancestors also + // exited, it must be the case that, walking downward from LastUselessPad, + // visiting just those nodes which have not been mapped to an unwind dest + // by getUnwindDestTokenHelper (the nullptr TempMemos notwithstanding, since + // they are just used to keep getUnwindDestTokenHelper from repeating work), + // any node visited must have been exhaustively searched with no information + // for it found. SmallVector<Instruction *, 8> Worklist(1, LastUselessPad); while (!Worklist.empty()) { Instruction *UselessPad = Worklist.pop_back_val(); - assert(!MemoMap.count(UselessPad) || MemoMap[UselessPad] == nullptr); + auto Memo = MemoMap.find(UselessPad); + if (Memo != MemoMap.end() && Memo->second) { + // Here the name 'UselessPad' is a bit of a misnomer, because we've found + // that it is a funclet that does have information about unwinding to + // a particular destination; its parent was a useless pad. + // Since its parent has no information, the unwind edge must not escape + // the parent, and must target a sibling of this pad. This local unwind + // gives us no information about EHPad. Leave it and the subtree rooted + // at it alone. + assert(getParentPad(Memo->second) == getParentPad(UselessPad)); + continue; + } + // We know we don't have information for UselesPad. If it has an entry in + // the MemoMap (mapping it to nullptr), it must be one of the TempMemos + // added on this invocation of getUnwindDestToken; if a previous invocation + // recorded nullptr, it would have had to prove that the ancestors of + // UselessPad, which include LastUselessPad, had no information, and that + // in turn would have required proving that the descendants of + // LastUselesPad, which include EHPad, have no information about + // LastUselessPad, which would imply that EHPad was mapped to nullptr in + // the MemoMap on that invocation, which isn't the case if we got here. + assert(!MemoMap.count(UselessPad) || TempMemos.count(UselessPad)); + // Assert as we enumerate users that 'UselessPad' doesn't have any unwind + // information that we'd be contradicting by making a map entry for it + // (which is something that getUnwindDestTokenHelper must have proved for + // us to get here). Just assert on is direct users here; the checks in + // this downward walk at its descendants will verify that they don't have + // any unwind edges that exit 'UselessPad' either (i.e. they either have no + // unwind edges or unwind to a sibling). 
MemoMap[UselessPad] = UnwindDestToken; if (auto *CatchSwitch = dyn_cast<CatchSwitchInst>(UselessPad)) { - for (BasicBlock *HandlerBlock : CatchSwitch->handlers()) - for (User *U : HandlerBlock->getFirstNonPHI()->users()) + assert(CatchSwitch->getUnwindDest() == nullptr && "Expected useless pad"); + for (BasicBlock *HandlerBlock : CatchSwitch->handlers()) { + auto *CatchPad = HandlerBlock->getFirstNonPHI(); + for (User *U : CatchPad->users()) { + assert( + (!isa<InvokeInst>(U) || + (getParentPad( + cast<InvokeInst>(U)->getUnwindDest()->getFirstNonPHI()) == + CatchPad)) && + "Expected useless pad"); if (isa<CatchSwitchInst>(U) || isa<CleanupPadInst>(U)) Worklist.push_back(cast<Instruction>(U)); + } + } } else { assert(isa<CleanupPadInst>(UselessPad)); - for (User *U : UselessPad->users()) + for (User *U : UselessPad->users()) { + assert(!isa<CleanupReturnInst>(U) && "Expected useless pad"); + assert((!isa<InvokeInst>(U) || + (getParentPad( + cast<InvokeInst>(U)->getUnwindDest()->getFirstNonPHI()) == + UselessPad)) && + "Expected useless pad"); if (isa<CatchSwitchInst>(U) || isa<CleanupPadInst>(U)) Worklist.push_back(cast<Instruction>(U)); + } } } @@ -463,37 +535,7 @@ static BasicBlock *HandleCallsInBlockInlinedThroughInvoke( #endif // NDEBUG } - // Convert this function call into an invoke instruction. First, split the - // basic block. - BasicBlock *Split = - BB->splitBasicBlock(CI->getIterator(), CI->getName() + ".noexc"); - - // Delete the unconditional branch inserted by splitBasicBlock - BB->getInstList().pop_back(); - - // Create the new invoke instruction. - SmallVector<Value*, 8> InvokeArgs(CI->arg_begin(), CI->arg_end()); - SmallVector<OperandBundleDef, 1> OpBundles; - - CI->getOperandBundlesAsDefs(OpBundles); - - // Note: we're round tripping operand bundles through memory here, and that - // can potentially be avoided with a cleverer API design that we do not have - // as of this time. - - InvokeInst *II = - InvokeInst::Create(CI->getCalledValue(), Split, UnwindEdge, InvokeArgs, - OpBundles, CI->getName(), BB); - II->setDebugLoc(CI->getDebugLoc()); - II->setCallingConv(CI->getCallingConv()); - II->setAttributes(CI->getAttributes()); - - // Make sure that anything using the call now uses the invoke! This also - // updates the CallGraph if present, because it uses a WeakVH. - CI->replaceAllUsesWith(II); - - // Delete the original call - Split->getInstList().pop_front(); + changeToInvokeAndSplitBasicBlock(CI, UnwindEdge); return BB; } return nullptr; @@ -718,7 +760,7 @@ static void PropagateParallelLoopAccessMetadata(CallSite CS, /// When inlining a function that contains noalias scope metadata, /// this metadata needs to be cloned so that the inlined blocks -/// have different "unqiue scopes" at every call site. Were this not done, then +/// have different "unique scopes" at every call site. Were this not done, then /// aliasing scopes from a function inlined into a caller multiple times could /// not be differentiated (and this would lead to miscompiles because the /// non-aliasing property communicated by the metadata could have @@ -1053,8 +1095,10 @@ static void AddAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap, /// If the inlined function has non-byval align arguments, then /// add @llvm.assume-based alignment assumptions to preserve this information. 
static void AddAlignmentAssumptions(CallSite CS, InlineFunctionInfo &IFI) { - if (!PreserveAlignmentAssumptions) + if (!PreserveAlignmentAssumptions || !IFI.GetAssumptionCache) return; + + AssumptionCache *AC = &(*IFI.GetAssumptionCache)(*CS.getCaller()); auto &DL = CS.getCaller()->getParent()->getDataLayout(); // To avoid inserting redundant assumptions, we should check for assumptions @@ -1077,13 +1121,12 @@ static void AddAlignmentAssumptions(CallSite CS, InlineFunctionInfo &IFI) { // If we can already prove the asserted alignment in the context of the // caller, then don't bother inserting the assumption. Value *Arg = CS.getArgument(I->getArgNo()); - if (getKnownAlignment(Arg, DL, CS.getInstruction(), - &IFI.ACT->getAssumptionCache(*CS.getCaller()), - &DT) >= Align) + if (getKnownAlignment(Arg, DL, CS.getInstruction(), AC, &DT) >= Align) continue; - IRBuilder<>(CS.getInstruction()) - .CreateAlignmentAssumption(DL, Arg, Align); + CallInst *NewAssumption = IRBuilder<>(CS.getInstruction()) + .CreateAlignmentAssumption(DL, Arg, Align); + AC->registerAssumption(NewAssumption); } } } @@ -1194,12 +1237,13 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, if (ByValAlignment <= 1) // 0 = unspecified, 1 = no particular alignment. return Arg; + AssumptionCache *AC = + IFI.GetAssumptionCache ? &(*IFI.GetAssumptionCache)(*Caller) : nullptr; const DataLayout &DL = Caller->getParent()->getDataLayout(); // If the pointer is already known to be sufficiently aligned, or if we can // round it up to a larger alignment, then we don't need a temporary. - if (getOrEnforceKnownAlignment(Arg, ByValAlignment, DL, TheCall, - &IFI.ACT->getAssumptionCache(*Caller)) >= + if (getOrEnforceKnownAlignment(Arg, ByValAlignment, DL, TheCall, AC) >= ByValAlignment) return Arg; @@ -1304,7 +1348,7 @@ static bool allocaWouldBeStaticInEntry(const AllocaInst *AI ) { /// Update inlined instructions' line numbers to /// to encode location where these instructions are inlined. static void fixupLineNumbers(Function *Fn, Function::iterator FI, - Instruction *TheCall) { + Instruction *TheCall, bool CalleeHasDebugInfo) { const DebugLoc &TheCallDL = TheCall->getDebugLoc(); if (!TheCallDL) return; @@ -1326,22 +1370,26 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, for (; FI != Fn->end(); ++FI) { for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ++BI) { - DebugLoc DL = BI->getDebugLoc(); - if (!DL) { - // If the inlined instruction has no line number, make it look as if it - // originates from the call location. This is important for - // ((__always_inline__, __nodebug__)) functions which must use caller - // location for all instructions in their function body. - - // Don't update static allocas, as they may get moved later. - if (auto *AI = dyn_cast<AllocaInst>(BI)) - if (allocaWouldBeStaticInEntry(AI)) - continue; - - BI->setDebugLoc(TheCallDL); - } else { - BI->setDebugLoc(updateInlinedAtInfo(DL, InlinedAtNode, BI->getContext(), IANodes)); + if (DebugLoc DL = BI->getDebugLoc()) { + BI->setDebugLoc( + updateInlinedAtInfo(DL, InlinedAtNode, BI->getContext(), IANodes)); + continue; } + + if (CalleeHasDebugInfo) + continue; + + // If the inlined instruction has no line number, make it look as if it + // originates from the call location. This is important for + // ((__always_inline__, __nodebug__)) functions which must use caller + // location for all instructions in their function body. + + // Don't update static allocas, as they may get moved later. 
+ if (auto *AI = dyn_cast<AllocaInst>(BI)) + if (allocaWouldBeStaticInEntry(AI)) + continue; + + BI->setDebugLoc(TheCallDL); } } } @@ -1597,8 +1645,11 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, if (IFI.CG) UpdateCallGraphAfterInlining(CS, FirstNewBlock, VMap, IFI); - // Update inlined instructions' line number information. - fixupLineNumbers(Caller, FirstNewBlock, TheCall); + // For 'nodebug' functions, the associated DISubprogram is always null. + // Conservatively avoid propagating the callsite debug location to + // instructions inlined from a function whose DISubprogram is not null. + fixupLineNumbers(Caller, FirstNewBlock, TheCall, + CalledFunc->getSubprogram() != nullptr); // Clone existing noalias metadata if necessary. CloneAliasScopeMetadata(CS, VMap); @@ -1609,10 +1660,15 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Propagate llvm.mem.parallel_loop_access if necessary. PropagateParallelLoopAccessMetadata(CS, VMap); - // FIXME: We could register any cloned assumptions instead of clearing the - // whole function's cache. - if (IFI.ACT) - IFI.ACT->getAssumptionCache(*Caller).clear(); + // Register any cloned assumptions. + if (IFI.GetAssumptionCache) + for (BasicBlock &NewBlock : + make_range(FirstNewBlock->getIterator(), Caller->end())) + for (Instruction &I : NewBlock) { + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::assume) + (*IFI.GetAssumptionCache)(*Caller).registerAssumption(II); + } } // If there are any alloca instructions in the block that used to be the entry @@ -1708,6 +1764,9 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, IRBuilder<> builder(&FirstNewBlock->front()); for (unsigned ai = 0, ae = IFI.StaticAllocas.size(); ai != ae; ++ai) { AllocaInst *AI = IFI.StaticAllocas[ai]; + // Don't mark swifterror allocas. They can't have bitcast uses. + if (AI->isSwiftError()) + continue; // If the alloca is already scoped to something smaller than the whole // function then there's no need to add redundant, less accurate markers. @@ -1949,6 +2008,20 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, std::swap(Returns, NormalReturns); } + // Now that all of the transforms on the inlined code have taken place but + // before we splice the inlined code into the CFG and lose track of which + // blocks were actually inlined, collect the call sites. We only do this if + // call graph updates weren't requested, as those provide value handle based + // tracking of inlined call sites instead. + if (InlinedFunctionInfo.ContainsCalls && !IFI.CG) { + // Otherwise just collect the raw call sites that were inlined. + for (BasicBlock &NewBB : + make_range(FirstNewBlock->getIterator(), Caller->end())) + for (Instruction &I : NewBB) + if (auto CS = CallSite(&I)) + IFI.InlinedCallSites.push_back(CS); + } + // If we cloned in _exactly one_ basic block, and if that block ends in a // return instruction, we splice the body of the inlined callee directly into // the calling basic block. @@ -2130,9 +2203,10 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // the entries are the same or undef). If so, remove the PHI so it doesn't // block other optimizations. if (PHI) { + AssumptionCache *AC = + IFI.GetAssumptionCache ? 
&(*IFI.GetAssumptionCache)(*Caller) : nullptr; auto &DL = Caller->getParent()->getDataLayout(); - if (Value *V = SimplifyInstruction(PHI, DL, nullptr, nullptr, - &IFI.ACT->getAssumptionCache(*Caller))) { + if (Value *V = SimplifyInstruction(PHI, DL, nullptr, nullptr, AC)) { PHI->replaceAllUsesWith(V); PHI->eraseFromParent(); } diff --git a/contrib/llvm/lib/Transforms/Utils/LCSSA.cpp b/contrib/llvm/lib/Transforms/Utils/LCSSA.cpp index 0d5a25b..68c6b74 100644 --- a/contrib/llvm/lib/Transforms/Utils/LCSSA.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LCSSA.cpp @@ -51,10 +51,19 @@ using namespace llvm; STATISTIC(NumLCSSA, "Number of live out of a loop variables"); +#ifdef EXPENSIVE_CHECKS +static bool VerifyLoopLCSSA = true; +#else +static bool VerifyLoopLCSSA = false; +#endif +static cl::opt<bool,true> +VerifyLoopLCSSAFlag("verify-loop-lcssa", cl::location(VerifyLoopLCSSA), + cl::desc("Verify loop lcssa form (time consuming)")); + /// Return true if the specified block is in the list. static bool isExitBlock(BasicBlock *BB, const SmallVectorImpl<BasicBlock *> &ExitBlocks) { - return find(ExitBlocks, BB) != ExitBlocks.end(); + return is_contained(ExitBlocks, BB); } /// For every instruction from the worklist, check to see if it has any uses @@ -63,19 +72,25 @@ static bool isExitBlock(BasicBlock *BB, bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, DominatorTree &DT, LoopInfo &LI) { SmallVector<Use *, 16> UsesToRewrite; - SmallVector<BasicBlock *, 8> ExitBlocks; SmallSetVector<PHINode *, 16> PHIsToRemove; PredIteratorCache PredCache; bool Changed = false; + // Cache the Loop ExitBlocks across this loop. We expect to get a lot of + // instructions within the same loops, computing the exit blocks is + // expensive, and we're not mutating the loop structure. + SmallDenseMap<Loop*, SmallVector<BasicBlock *,1>> LoopExitBlocks; + while (!Worklist.empty()) { UsesToRewrite.clear(); - ExitBlocks.clear(); Instruction *I = Worklist.pop_back_val(); BasicBlock *InstBB = I->getParent(); Loop *L = LI.getLoopFor(InstBB); - L->getExitBlocks(ExitBlocks); + if (!LoopExitBlocks.count(L)) + L->getExitBlocks(LoopExitBlocks[L]); + assert(LoopExitBlocks.count(L)); + const SmallVectorImpl<BasicBlock *> &ExitBlocks = LoopExitBlocks[L]; if (ExitBlocks.empty()) continue; @@ -186,14 +201,14 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, // Otherwise, do full PHI insertion. SSAUpdate.RewriteUse(*UseToRewrite); + } - // SSAUpdater might have inserted phi-nodes inside other loops. We'll need - // to post-process them to keep LCSSA form. - for (PHINode *InsertedPN : InsertedPHIs) { - if (auto *OtherLoop = LI.getLoopFor(InsertedPN->getParent())) - if (!L->contains(OtherLoop)) - PostProcessPHIs.push_back(InsertedPN); - } + // SSAUpdater might have inserted phi-nodes inside other loops. We'll need + // to post-process them to keep LCSSA form. 
+ for (PHINode *InsertedPN : InsertedPHIs) { + if (auto *OtherLoop = LI.getLoopFor(InsertedPN->getParent())) + if (!L->contains(OtherLoop)) + PostProcessPHIs.push_back(InsertedPN); } // Post process PHI instructions that were inserted into another disjoint @@ -229,7 +244,7 @@ blockDominatesAnExit(BasicBlock *BB, DominatorTree &DT, const SmallVectorImpl<BasicBlock *> &ExitBlocks) { DomTreeNode *DomNode = DT.getNode(BB); - return llvm::any_of(ExitBlocks, [&](BasicBlock * EB) { + return any_of(ExitBlocks, [&](BasicBlock *EB) { return DT.dominates(DomNode, DT.getNode(EB)); }); } @@ -315,6 +330,19 @@ struct LCSSAWrapperPass : public FunctionPass { ScalarEvolution *SE; bool runOnFunction(Function &F) override; + void verifyAnalysis() const override { + // This check is very expensive. On the loop intensive compiles it may cause + // up to 10x slowdown. Currently it's disabled by default. LPPassManager + // always does limited form of the LCSSA verification. Similar reasoning + // was used for the LoopInfo verifier. + if (VerifyLoopLCSSA) { + assert(all_of(*LI, + [&](Loop *L) { + return L->isRecursivelyLCSSAForm(*DT, *LI); + }) && + "LCSSA form is broken!"); + } + }; /// This transformation requires natural loop information & requires that /// loop preheaders be inserted into the CFG. It maintains both of these, @@ -330,6 +358,10 @@ struct LCSSAWrapperPass : public FunctionPass { AU.addPreserved<GlobalsAAWrapperPass>(); AU.addPreserved<ScalarEvolutionWrapperPass>(); AU.addPreserved<SCEVAAWrapperPass>(); + + // This is needed to perform LCSSA verification inside LPPassManager + AU.addRequired<LCSSAVerificationPass>(); + AU.addPreserved<LCSSAVerificationPass>(); } }; } @@ -339,6 +371,7 @@ INITIALIZE_PASS_BEGIN(LCSSAWrapperPass, "lcssa", "Loop-Closed SSA Form Pass", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LCSSAVerificationPass) INITIALIZE_PASS_END(LCSSAWrapperPass, "lcssa", "Loop-Closed SSA Form Pass", false, false) @@ -355,7 +388,7 @@ bool LCSSAWrapperPass::runOnFunction(Function &F) { return formLCSSAOnAllLoops(LI, *DT, SE); } -PreservedAnalyses LCSSAPass::run(Function &F, AnalysisManager<Function> &AM) { +PreservedAnalyses LCSSAPass::run(Function &F, FunctionAnalysisManager &AM) { auto &LI = AM.getResult<LoopAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto *SE = AM.getCachedResult<ScalarEvolutionAnalysis>(F); diff --git a/contrib/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/contrib/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp new file mode 100644 index 0000000..d97cd75 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -0,0 +1,571 @@ +//===-- LibCallsShrinkWrap.cpp ----------------------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass shrink-wraps a call to function if the result is not used. +// The call can set errno but is otherwise side effect free. For example: +// sqrt(val); +// is transformed to +// if (val < 0) +// sqrt(val); +// Even if the result of library call is not being used, the compiler cannot +// safely delete the call because the function can set errno on error +// conditions. +// Note in many functions, the error condition solely depends on the incoming +// parameter. 
In this optimization, we can generate the condition can lead to +// the errno to shrink-wrap the call. Since the chances of hitting the error +// condition is low, the runtime call is effectively eliminated. +// +// These partially dead calls are usually results of C++ abstraction penalty +// exposed by inlining. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/LibCallsShrinkWrap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +using namespace llvm; + +#define DEBUG_TYPE "libcalls-shrinkwrap" + +STATISTIC(NumWrappedOneCond, "Number of One-Condition Wrappers Inserted"); +STATISTIC(NumWrappedTwoCond, "Number of Two-Condition Wrappers Inserted"); + +static cl::opt<bool> LibCallsShrinkWrapDoDomainError( + "libcalls-shrinkwrap-domain-error", cl::init(true), cl::Hidden, + cl::desc("Perform shrink-wrap on lib calls with domain errors")); +static cl::opt<bool> LibCallsShrinkWrapDoRangeError( + "libcalls-shrinkwrap-range-error", cl::init(true), cl::Hidden, + cl::desc("Perform shrink-wrap on lib calls with range errors")); +static cl::opt<bool> LibCallsShrinkWrapDoPoleError( + "libcalls-shrinkwrap-pole-error", cl::init(true), cl::Hidden, + cl::desc("Perform shrink-wrap on lib calls with pole errors")); + +namespace { +class LibCallsShrinkWrapLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + explicit LibCallsShrinkWrapLegacyPass() : FunctionPass(ID) { + initializeLibCallsShrinkWrapLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + void getAnalysisUsage(AnalysisUsage &AU) const override; + bool runOnFunction(Function &F) override; +}; +} + +char LibCallsShrinkWrapLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LibCallsShrinkWrapLegacyPass, "libcalls-shrinkwrap", + "Conditionally eliminate dead library calls", false, + false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(LibCallsShrinkWrapLegacyPass, "libcalls-shrinkwrap", + "Conditionally eliminate dead library calls", false, false) + +namespace { +class LibCallsShrinkWrap : public InstVisitor<LibCallsShrinkWrap> { +public: + LibCallsShrinkWrap(const TargetLibraryInfo &TLI) : TLI(TLI), Changed(false){}; + bool isChanged() const { return Changed; } + void visitCallInst(CallInst &CI) { checkCandidate(CI); } + void perform() { + for (auto &CI : WorkList) { + DEBUG(dbgs() << "CDCE calls: " << CI->getCalledFunction()->getName() + << "\n"); + if (perform(CI)) { + Changed = true; + DEBUG(dbgs() << "Transformed\n"); + } + } + } + +private: + bool perform(CallInst *CI); + void checkCandidate(CallInst &CI); + void shrinkWrapCI(CallInst *CI, Value *Cond); + bool performCallDomainErrorOnly(CallInst *CI, const LibFunc::Func &Func); + bool performCallErrors(CallInst *CI, const LibFunc::Func &Func); + bool performCallRangeErrorOnly(CallInst *CI, const LibFunc::Func &Func); + Value *generateOneRangeCond(CallInst *CI, const LibFunc::Func &Func); + Value *generateTwoRangeCond(CallInst *CI, const LibFunc::Func &Func); + Value *generateCondForPow(CallInst *CI, const 
LibFunc::Func &Func); + + // Create an OR of two conditions. + Value *createOrCond(CallInst *CI, CmpInst::Predicate Cmp, float Val, + CmpInst::Predicate Cmp2, float Val2) { + IRBuilder<> BBBuilder(CI); + Value *Arg = CI->getArgOperand(0); + auto Cond2 = createCond(BBBuilder, Arg, Cmp2, Val2); + auto Cond1 = createCond(BBBuilder, Arg, Cmp, Val); + return BBBuilder.CreateOr(Cond1, Cond2); + } + + // Create a single condition using IRBuilder. + Value *createCond(IRBuilder<> &BBBuilder, Value *Arg, CmpInst::Predicate Cmp, + float Val) { + Constant *V = ConstantFP::get(BBBuilder.getContext(), APFloat(Val)); + if (!Arg->getType()->isFloatTy()) + V = ConstantExpr::getFPExtend(V, Arg->getType()); + return BBBuilder.CreateFCmp(Cmp, Arg, V); + } + + // Create a single condition. + Value *createCond(CallInst *CI, CmpInst::Predicate Cmp, float Val) { + IRBuilder<> BBBuilder(CI); + Value *Arg = CI->getArgOperand(0); + return createCond(BBBuilder, Arg, Cmp, Val); + } + + const TargetLibraryInfo &TLI; + SmallVector<CallInst *, 16> WorkList; + bool Changed; +}; +} // end anonymous namespace + +// Perform the transformation to calls with errno set by domain error. +bool LibCallsShrinkWrap::performCallDomainErrorOnly(CallInst *CI, + const LibFunc::Func &Func) { + Value *Cond = nullptr; + + switch (Func) { + case LibFunc::acos: // DomainError: (x < -1 || x > 1) + case LibFunc::acosf: // Same as acos + case LibFunc::acosl: // Same as acos + case LibFunc::asin: // DomainError: (x < -1 || x > 1) + case LibFunc::asinf: // Same as asin + case LibFunc::asinl: // Same as asin + { + ++NumWrappedTwoCond; + Cond = createOrCond(CI, CmpInst::FCMP_OLT, -1.0f, CmpInst::FCMP_OGT, 1.0f); + break; + } + case LibFunc::cos: // DomainError: (x == +inf || x == -inf) + case LibFunc::cosf: // Same as cos + case LibFunc::cosl: // Same as cos + case LibFunc::sin: // DomainError: (x == +inf || x == -inf) + case LibFunc::sinf: // Same as sin + case LibFunc::sinl: // Same as sin + { + ++NumWrappedTwoCond; + Cond = createOrCond(CI, CmpInst::FCMP_OEQ, INFINITY, CmpInst::FCMP_OEQ, + -INFINITY); + break; + } + case LibFunc::acosh: // DomainError: (x < 1) + case LibFunc::acoshf: // Same as acosh + case LibFunc::acoshl: // Same as acosh + { + ++NumWrappedOneCond; + Cond = createCond(CI, CmpInst::FCMP_OLT, 1.0f); + break; + } + case LibFunc::sqrt: // DomainError: (x < 0) + case LibFunc::sqrtf: // Same as sqrt + case LibFunc::sqrtl: // Same as sqrt + { + ++NumWrappedOneCond; + Cond = createCond(CI, CmpInst::FCMP_OLT, 0.0f); + break; + } + default: + return false; + } + shrinkWrapCI(CI, Cond); + return true; +} + +// Perform the transformation to calls with errno set by range error. 
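// In source terms, the guard built above for the acos/asin family is roughly:
//
//   if (x < -1.0 || x > 1.0)   // only this range can set errno (EDOM)
//     acos(x);                 // result unused; call kept on the cold path
//
// createOrCond ORs the two fcmp conditions (FCMP_OLT -1.0, FCMP_OGT 1.0), and
// createCond fpext's the float constants when the argument is double or
// x86_fp80.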
+bool LibCallsShrinkWrap::performCallRangeErrorOnly(CallInst *CI, + const LibFunc::Func &Func) { + Value *Cond = nullptr; + + switch (Func) { + case LibFunc::cosh: + case LibFunc::coshf: + case LibFunc::coshl: + case LibFunc::exp: + case LibFunc::expf: + case LibFunc::expl: + case LibFunc::exp10: + case LibFunc::exp10f: + case LibFunc::exp10l: + case LibFunc::exp2: + case LibFunc::exp2f: + case LibFunc::exp2l: + case LibFunc::sinh: + case LibFunc::sinhf: + case LibFunc::sinhl: { + Cond = generateTwoRangeCond(CI, Func); + break; + } + case LibFunc::expm1: // RangeError: (709, inf) + case LibFunc::expm1f: // RangeError: (88, inf) + case LibFunc::expm1l: // RangeError: (11356, inf) + { + Cond = generateOneRangeCond(CI, Func); + break; + } + default: + return false; + } + shrinkWrapCI(CI, Cond); + return true; +} + +// Perform the transformation to calls with errno set by combination of errors. +bool LibCallsShrinkWrap::performCallErrors(CallInst *CI, + const LibFunc::Func &Func) { + Value *Cond = nullptr; + + switch (Func) { + case LibFunc::atanh: // DomainError: (x < -1 || x > 1) + // PoleError: (x == -1 || x == 1) + // Overall Cond: (x <= -1 || x >= 1) + case LibFunc::atanhf: // Same as atanh + case LibFunc::atanhl: // Same as atanh + { + if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError) + return false; + ++NumWrappedTwoCond; + Cond = createOrCond(CI, CmpInst::FCMP_OLE, -1.0f, CmpInst::FCMP_OGE, 1.0f); + break; + } + case LibFunc::log: // DomainError: (x < 0) + // PoleError: (x == 0) + // Overall Cond: (x <= 0) + case LibFunc::logf: // Same as log + case LibFunc::logl: // Same as log + case LibFunc::log10: // Same as log + case LibFunc::log10f: // Same as log + case LibFunc::log10l: // Same as log + case LibFunc::log2: // Same as log + case LibFunc::log2f: // Same as log + case LibFunc::log2l: // Same as log + case LibFunc::logb: // Same as log + case LibFunc::logbf: // Same as log + case LibFunc::logbl: // Same as log + { + if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError) + return false; + ++NumWrappedOneCond; + Cond = createCond(CI, CmpInst::FCMP_OLE, 0.0f); + break; + } + case LibFunc::log1p: // DomainError: (x < -1) + // PoleError: (x == -1) + // Overall Cond: (x <= -1) + case LibFunc::log1pf: // Same as log1p + case LibFunc::log1pl: // Same as log1p + { + if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError) + return false; + ++NumWrappedOneCond; + Cond = createCond(CI, CmpInst::FCMP_OLE, -1.0f); + break; + } + case LibFunc::pow: // DomainError: x < 0 and y is noninteger + // PoleError: x == 0 and y < 0 + // RangeError: overflow or underflow + case LibFunc::powf: + case LibFunc::powl: { + if (!LibCallsShrinkWrapDoDomainError || !LibCallsShrinkWrapDoPoleError || + !LibCallsShrinkWrapDoRangeError) + return false; + Cond = generateCondForPow(CI, Func); + if (Cond == nullptr) + return false; + break; + } + default: + return false; + } + assert(Cond && "performCallErrors should not see an empty condition"); + shrinkWrapCI(CI, Cond); + return true; +} + +// Checks if CI is a candidate for shrinkwrapping and put it into work list if +// true. +void LibCallsShrinkWrap::checkCandidate(CallInst &CI) { + if (CI.isNoBuiltin()) + return; + // A possible improvement is to handle the calls with the return value being + // used. 
If there is API for fast libcall implementation without setting + // errno, we can use the same framework to direct/wrap the call to the fast + // API in the error free path, and leave the original call in the slow path. + if (!CI.use_empty()) + return; + + LibFunc::Func Func; + Function *Callee = CI.getCalledFunction(); + if (!Callee) + return; + if (!TLI.getLibFunc(*Callee, Func) || !TLI.has(Func)) + return; + + if (CI.getNumArgOperands() == 0) + return; + // TODO: Handle long double in other formats. + Type *ArgType = CI.getArgOperand(0)->getType(); + if (!(ArgType->isFloatTy() || ArgType->isDoubleTy() || + ArgType->isX86_FP80Ty())) + return; + + WorkList.push_back(&CI); +} + +// Generate the upper bound condition for RangeError. +Value *LibCallsShrinkWrap::generateOneRangeCond(CallInst *CI, + const LibFunc::Func &Func) { + float UpperBound; + switch (Func) { + case LibFunc::expm1: // RangeError: (709, inf) + UpperBound = 709.0f; + break; + case LibFunc::expm1f: // RangeError: (88, inf) + UpperBound = 88.0f; + break; + case LibFunc::expm1l: // RangeError: (11356, inf) + UpperBound = 11356.0f; + break; + default: + llvm_unreachable("Should be reach here"); + } + + ++NumWrappedOneCond; + return createCond(CI, CmpInst::FCMP_OGT, UpperBound); +} + +// Generate the lower and upper bound condition for RangeError. +Value *LibCallsShrinkWrap::generateTwoRangeCond(CallInst *CI, + const LibFunc::Func &Func) { + float UpperBound, LowerBound; + switch (Func) { + case LibFunc::cosh: // RangeError: (x < -710 || x > 710) + case LibFunc::sinh: // Same as cosh + LowerBound = -710.0f; + UpperBound = 710.0f; + break; + case LibFunc::coshf: // RangeError: (x < -89 || x > 89) + case LibFunc::sinhf: // Same as coshf + LowerBound = -89.0f; + UpperBound = 89.0f; + break; + case LibFunc::coshl: // RangeError: (x < -11357 || x > 11357) + case LibFunc::sinhl: // Same as coshl + LowerBound = -11357.0f; + UpperBound = 11357.0f; + break; + case LibFunc::exp: // RangeError: (x < -745 || x > 709) + LowerBound = -745.0f; + UpperBound = 709.0f; + break; + case LibFunc::expf: // RangeError: (x < -103 || x > 88) + LowerBound = -103.0f; + UpperBound = 88.0f; + break; + case LibFunc::expl: // RangeError: (x < -11399 || x > 11356) + LowerBound = -11399.0f; + UpperBound = 11356.0f; + break; + case LibFunc::exp10: // RangeError: (x < -323 || x > 308) + LowerBound = -323.0f; + UpperBound = 308.0f; + break; + case LibFunc::exp10f: // RangeError: (x < -45 || x > 38) + LowerBound = -45.0f; + UpperBound = 38.0f; + break; + case LibFunc::exp10l: // RangeError: (x < -4950 || x > 4932) + LowerBound = -4950.0f; + UpperBound = 4932.0f; + break; + case LibFunc::exp2: // RangeError: (x < -1074 || x > 1023) + LowerBound = -1074.0f; + UpperBound = 1023.0f; + break; + case LibFunc::exp2f: // RangeError: (x < -149 || x > 127) + LowerBound = -149.0f; + UpperBound = 127.0f; + break; + case LibFunc::exp2l: // RangeError: (x < -16445 || x > 11383) + LowerBound = -16445.0f; + UpperBound = 11383.0f; + break; + default: + llvm_unreachable("Should be reach here"); + } + + ++NumWrappedTwoCond; + return createOrCond(CI, CmpInst::FCMP_OGT, UpperBound, CmpInst::FCMP_OLT, + LowerBound); +} + +// For pow(x,y), We only handle the following cases: +// (1) x is a constant && (x >= 1) && (x < MaxUInt8) +// Cond is: (y > 127) +// (2) x is a value coming from an integer type. 
+// (2.1) if x's bit_size == 8 +// Cond: (x <= 0 || y > 128) +// (2.2) if x's bit_size is 16 +// Cond: (x <= 0 || y > 64) +// (2.3) if x's bit_size is 32 +// Cond: (x <= 0 || y > 32) +// Support for powl(x,y) and powf(x,y) are TBD. +// +// Note that condition can be more conservative than the actual condition +// (i.e. we might invoke the calls that will not set the errno.). +// +Value *LibCallsShrinkWrap::generateCondForPow(CallInst *CI, + const LibFunc::Func &Func) { + // FIXME: LibFunc::powf and powl TBD. + if (Func != LibFunc::pow) { + DEBUG(dbgs() << "Not handled powf() and powl()\n"); + return nullptr; + } + + Value *Base = CI->getArgOperand(0); + Value *Exp = CI->getArgOperand(1); + IRBuilder<> BBBuilder(CI); + + // Constant Base case. + if (ConstantFP *CF = dyn_cast<ConstantFP>(Base)) { + double D = CF->getValueAPF().convertToDouble(); + if (D < 1.0f || D > APInt::getMaxValue(8).getZExtValue()) { + DEBUG(dbgs() << "Not handled pow(): constant base out of range\n"); + return nullptr; + } + + ++NumWrappedOneCond; + Constant *V = ConstantFP::get(CI->getContext(), APFloat(127.0f)); + if (!Exp->getType()->isFloatTy()) + V = ConstantExpr::getFPExtend(V, Exp->getType()); + return BBBuilder.CreateFCmp(CmpInst::FCMP_OGT, Exp, V); + } + + // If the Base value coming from an integer type. + Instruction *I = dyn_cast<Instruction>(Base); + if (!I) { + DEBUG(dbgs() << "Not handled pow(): FP type base\n"); + return nullptr; + } + unsigned Opcode = I->getOpcode(); + if (Opcode == Instruction::UIToFP || Opcode == Instruction::SIToFP) { + unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits(); + float UpperV = 0.0f; + if (BW == 8) + UpperV = 128.0f; + else if (BW == 16) + UpperV = 64.0f; + else if (BW == 32) + UpperV = 32.0f; + else { + DEBUG(dbgs() << "Not handled pow(): type too wide\n"); + return nullptr; + } + + ++NumWrappedTwoCond; + Constant *V = ConstantFP::get(CI->getContext(), APFloat(UpperV)); + Constant *V0 = ConstantFP::get(CI->getContext(), APFloat(0.0f)); + if (!Exp->getType()->isFloatTy()) + V = ConstantExpr::getFPExtend(V, Exp->getType()); + if (!Base->getType()->isFloatTy()) + V0 = ConstantExpr::getFPExtend(V0, Exp->getType()); + + Value *Cond = BBBuilder.CreateFCmp(CmpInst::FCMP_OGT, Exp, V); + Value *Cond0 = BBBuilder.CreateFCmp(CmpInst::FCMP_OLE, Base, V0); + return BBBuilder.CreateOr(Cond0, Cond); + } + DEBUG(dbgs() << "Not handled pow(): base not from integer convert\n"); + return nullptr; +} + +// Wrap conditions that can potentially generate errno to the library call. +void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) { + assert(Cond != nullptr && "hrinkWrapCI is not expecting an empty call inst"); + MDNode *BranchWeights = + MDBuilder(CI->getContext()).createBranchWeights(1, 2000); + TerminatorInst *NewInst = + SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights); + BasicBlock *CallBB = NewInst->getParent(); + CallBB->setName("cdce.call"); + CallBB->getSingleSuccessor()->setName("cdce.end"); + CI->removeFromParent(); + CallBB->getInstList().insert(CallBB->getFirstInsertionPt(), CI); + DEBUG(dbgs() << "== Basic Block After =="); + DEBUG(dbgs() << *CallBB->getSinglePredecessor() << *CallBB + << *CallBB->getSingleSuccessor() << "\n"); +} + +// Perform the transformation to a single candidate. 
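// The control flow shrinkWrapCI produces for, e.g., an unused sqrt(x) call is
// roughly the following, with the 1:2000 weights marking the call path cold:
//
//     %cond = fcmp olt double %x, 0.000000e+00
//     br i1 %cond, label %cdce.call, label %cdce.end, !prof !0
//   cdce.call:
//     call double @sqrt(double %x)
//     br label %cdce.end
//   cdce.end:
//     ...
//   !0 = !{!"branch_weights", i32 1, i32 2000}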
+bool LibCallsShrinkWrap::perform(CallInst *CI) { + LibFunc::Func Func; + Function *Callee = CI->getCalledFunction(); + assert(Callee && "perform() should apply to a non-empty callee"); + TLI.getLibFunc(*Callee, Func); + assert(Func && "perform() is not expecting an empty function"); + + if (LibCallsShrinkWrapDoDomainError && performCallDomainErrorOnly(CI, Func)) + return true; + + if (LibCallsShrinkWrapDoRangeError && performCallRangeErrorOnly(CI, Func)) + return true; + + return performCallErrors(CI, Func); +} + +void LibCallsShrinkWrapLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); +} + +static bool runImpl(Function &F, const TargetLibraryInfo &TLI) { + if (F.hasFnAttribute(Attribute::OptimizeForSize)) + return false; + LibCallsShrinkWrap CCDCE(TLI); + CCDCE.visit(F); + CCDCE.perform(); + return CCDCE.isChanged(); +} + +bool LibCallsShrinkWrapLegacyPass::runOnFunction(Function &F) { + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return runImpl(F, TLI); +} + +namespace llvm { +char &LibCallsShrinkWrapPassID = LibCallsShrinkWrapLegacyPass::ID; + +// Public interface to LibCallsShrinkWrap pass. +FunctionPass *createLibCallsShrinkWrapPass() { + return new LibCallsShrinkWrapLegacyPass(); +} + +PreservedAnalyses LibCallsShrinkWrapPass::run(Function &F, + FunctionAnalysisManager &FAM) { + auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); + bool Changed = runImpl(F, TLI); + if (!Changed) + return PreservedAnalyses::all(); + auto PA = PreservedAnalyses(); + PA.preserve<GlobalsAA>(); + return PA; +} +} diff --git a/contrib/llvm/lib/Transforms/Utils/Local.cpp b/contrib/llvm/lib/Transforms/Utils/Local.cpp index f1838d8..6e4174a 100644 --- a/contrib/llvm/lib/Transforms/Utils/Local.cpp +++ b/contrib/llvm/lib/Transforms/Utils/Local.cpp @@ -340,6 +340,10 @@ bool llvm::isInstructionTriviallyDead(Instruction *I, if (Constant *C = dyn_cast<Constant>(CI->getArgOperand(0))) return C->isNullValue() || isa<UndefValue>(C); + if (CallSite CS = CallSite(I)) + if (isMathLibCallNoop(CS, TLI)) + return true; + return false; } @@ -886,6 +890,17 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB) { } } + // If the unconditional branch we replaced contains llvm.loop metadata, we + // add the metadata to the branch instructions in the predecessors. + unsigned LoopMDKind = BB->getContext().getMDKindID("llvm.loop"); + Instruction *TI = BB->getTerminator(); + if (TI) + if (MDNode *LoopMD = TI->getMetadata(LoopMDKind)) + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { + BasicBlock *Pred = *PI; + Pred->getTerminator()->setMetadata(LoopMDKind, LoopMD); + } + // Everything that jumped to BB now goes to Succ. BB->replaceAllUsesWith(Succ); if (!Succ->hasName()) Succ->takeName(BB); @@ -1001,10 +1016,6 @@ static unsigned enforceKnownAlignment(Value *V, unsigned Align, return Align; } -/// getOrEnforceKnownAlignment - If the specified pointer has an alignment that -/// we can determine, return it, otherwise return 0. If PrefAlign is specified, -/// and it is more than the alignment of the ultimate object, see if we can -/// increase the alignment of the ultimate object, making this check succeed. 
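// The llvm.loop propagation added to TryToSimplifyUncondBranchFromEmptyBlock
// above matters when the block being folded carried loop metadata; roughly:
//
//   pred:                                pred:
//     br label %bb                         br label %succ, !llvm.loop !1
//   bb:                            =>    succ:
//     br label %succ, !llvm.loop !1        ...
//   succ:
//     ...
//
// Without copying !1 to the predecessors' terminators, hints such as
// llvm.loop.unroll.* attached to that branch would silently be dropped.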
unsigned llvm::getOrEnforceKnownAlignment(Value *V, unsigned PrefAlign, const DataLayout &DL, const Instruction *CxtI, @@ -1057,9 +1068,27 @@ static bool LdStHasDebugValue(DILocalVariable *DIVar, DIExpression *DIExpr, return false; } +/// See if there is a dbg.value intrinsic for DIVar for the PHI node. +static bool PhiHasDebugValue(DILocalVariable *DIVar, + DIExpression *DIExpr, + PHINode *APN) { + // Since we can't guarantee that the original dbg.declare instrinsic + // is removed by LowerDbgDeclare(), we need to make sure that we are + // not inserting the same dbg.value intrinsic over and over. + DbgValueList DbgValues; + FindAllocaDbgValues(DbgValues, APN); + for (auto DVI : DbgValues) { + assert (DVI->getValue() == APN); + assert (DVI->getOffset() == 0); + if ((DVI->getVariable() == DIVar) && (DVI->getExpression() == DIExpr)) + return true; + } + return false; +} + /// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value /// that has an associated llvm.dbg.decl intrinsic. -bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, +void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, StoreInst *SI, DIBuilder &Builder) { auto *DIVar = DDI->getVariable(); auto *DIExpr = DDI->getExpression(); @@ -1073,26 +1102,27 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0))) ExtendedArg = dyn_cast<Argument>(SExt->getOperand(0)); if (ExtendedArg) { - // We're now only describing a subset of the variable. The piece we're + // We're now only describing a subset of the variable. The fragment we're // describing will always be smaller than the variable size, because // VariableSize == Size of Alloca described by DDI. Since SI stores // to the alloca described by DDI, if it's first operand is an extend, // we're guaranteed that before extension, the value was narrower than // the size of the alloca, hence the size of the described variable. SmallVector<uint64_t, 3> Ops; - unsigned PieceOffset = 0; - // If this already is a bit piece, we drop the bit piece from the expression - // and record the offset. - if (DIExpr->isBitPiece()) { + unsigned FragmentOffset = 0; + // If this already is a bit fragment, we drop the bit fragment from the + // expression and record the offset. + auto Fragment = DIExpr->getFragmentInfo(); + if (Fragment) { Ops.append(DIExpr->elements_begin(), DIExpr->elements_end()-3); - PieceOffset = DIExpr->getBitPieceOffset(); + FragmentOffset = Fragment->OffsetInBits; } else { Ops.append(DIExpr->elements_begin(), DIExpr->elements_end()); } - Ops.push_back(dwarf::DW_OP_bit_piece); - Ops.push_back(PieceOffset); // Offset + Ops.push_back(dwarf::DW_OP_LLVM_fragment); + Ops.push_back(FragmentOffset); const DataLayout &DL = DDI->getModule()->getDataLayout(); - Ops.push_back(DL.getTypeSizeInBits(ExtendedArg->getType())); // Size + Ops.push_back(DL.getTypeSizeInBits(ExtendedArg->getType())); auto NewDIExpr = Builder.createExpression(Ops); if (!LdStHasDebugValue(DIVar, NewDIExpr, SI)) Builder.insertDbgValueIntrinsic(ExtendedArg, 0, DIVar, NewDIExpr, @@ -1100,19 +1130,18 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, } else if (!LdStHasDebugValue(DIVar, DIExpr, SI)) Builder.insertDbgValueIntrinsic(SI->getOperand(0), 0, DIVar, DIExpr, DDI->getDebugLoc(), SI); - return true; } /// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value /// that has an associated llvm.dbg.decl intrinsic. 
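// A worked example of the fragment rewrite above (metadata numbers are
// illustrative): for
//
//   %c = zext i8 %a to i32
//   store i32 %c, i32* %x.addr        ; %x is a 32-bit variable
//
// the store is described by a dbg.value that covers only the low 8 bits that
// actually carry information:
//
//   call void @llvm.dbg.value(metadata i8 %a, i64 0, metadata !12,
//                        metadata !DIExpression(DW_OP_LLVM_fragment, 0, 8))
//
// i.e. fragment offset 0 and the bit width of the pre-extension argument,
// replacing the old DW_OP_bit_piece encoding.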
-bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, +void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, LoadInst *LI, DIBuilder &Builder) { auto *DIVar = DDI->getVariable(); auto *DIExpr = DDI->getExpression(); assert(DIVar && "Missing variable"); if (LdStHasDebugValue(DIVar, DIExpr, LI)) - return true; + return; // We are now tracking the loaded value instead of the address. In the // future if multi-location support is added to the IR, it might be @@ -1121,7 +1150,28 @@ bool llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, Instruction *DbgValue = Builder.insertDbgValueIntrinsic( LI, 0, DIVar, DIExpr, DDI->getDebugLoc(), (Instruction *)nullptr); DbgValue->insertAfter(LI); - return true; +} + +/// Inserts a llvm.dbg.value intrinsic after a phi +/// that has an associated llvm.dbg.decl intrinsic. +void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, + PHINode *APN, DIBuilder &Builder) { + auto *DIVar = DDI->getVariable(); + auto *DIExpr = DDI->getExpression(); + assert(DIVar && "Missing variable"); + + if (PhiHasDebugValue(DIVar, DIExpr, APN)) + return; + + BasicBlock *BB = APN->getParent(); + auto InsertionPt = BB->getFirstInsertionPt(); + + // The block may be a catchswitch block, which does not have a valid + // insertion point. + // FIXME: Insert dbg.value markers in the successors when appropriate. + if (InsertionPt != BB->end()) + Builder.insertDbgValueIntrinsic(APN, 0, DIVar, DIExpr, DDI->getDebugLoc(), + &*InsertionPt); } /// Determine whether this alloca is either a VLA or an array. @@ -1191,6 +1241,16 @@ DbgDeclareInst *llvm::FindAllocaDbgDeclare(Value *V) { return nullptr; } +/// FindAllocaDbgValues - Finds the llvm.dbg.value intrinsics describing the +/// alloca 'V', if any. +void llvm::FindAllocaDbgValues(DbgValueList &DbgValues, Value *V) { + if (auto *L = LocalAsMetadata::getIfExists(V)) + if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L)) + for (User *U : MDV->users()) + if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(U)) + DbgValues.push_back(DVI); +} + static void DIExprAddDeref(SmallVectorImpl<uint64_t> &Expr) { Expr.push_back(dwarf::DW_OP_deref); } @@ -1310,12 +1370,13 @@ unsigned llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { return NumDeadInst; } -unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap) { +unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap, + bool PreserveLCSSA) { BasicBlock *BB = I->getParent(); // Loop over all of the successors, removing BB's entry from any PHI // nodes. for (BasicBlock *Successor : successors(BB)) - Successor->removePredecessor(BB); + Successor->removePredecessor(BB, PreserveLCSSA); // Insert a call to llvm.trap right before this. This turns the undefined // behavior into a hard fail instead of falling through into random code. @@ -1360,6 +1421,43 @@ static void changeToCall(InvokeInst *II) { II->eraseFromParent(); } +BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, + BasicBlock *UnwindEdge) { + BasicBlock *BB = CI->getParent(); + + // Convert this function call into an invoke instruction. First, split the + // basic block. + BasicBlock *Split = + BB->splitBasicBlock(CI->getIterator(), CI->getName() + ".noexc"); + + // Delete the unconditional branch inserted by splitBasicBlock + BB->getInstList().pop_back(); + + // Create the new invoke instruction. 
+ SmallVector<Value *, 8> InvokeArgs(CI->arg_begin(), CI->arg_end()); + SmallVector<OperandBundleDef, 1> OpBundles; + + CI->getOperandBundlesAsDefs(OpBundles); + + // Note: we're round tripping operand bundles through memory here, and that + // can potentially be avoided with a cleverer API design that we do not have + // as of this time. + + InvokeInst *II = InvokeInst::Create(CI->getCalledValue(), Split, UnwindEdge, + InvokeArgs, OpBundles, CI->getName(), BB); + II->setDebugLoc(CI->getDebugLoc()); + II->setCallingConv(CI->getCallingConv()); + II->setAttributes(CI->getAttributes()); + + // Make sure that anything using the call now uses the invoke! This also + // updates the CallGraph if present, because it uses a WeakVH. + CI->replaceAllUsesWith(II); + + // Delete the original call + Split->getInstList().pop_front(); + return Split; +} + static bool markAliveBlocks(Function &F, SmallPtrSetImpl<BasicBlock*> &Reachable) { @@ -1586,10 +1684,10 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; K->dropUnknownNonDebugMetadata(KnownIDs); K->getAllMetadataOtherThanDebugLoc(Metadata); - for (unsigned i = 0, n = Metadata.size(); i < n; ++i) { - unsigned Kind = Metadata[i].first; + for (const auto &MD : Metadata) { + unsigned Kind = MD.first; MDNode *JMD = J->getMetadata(Kind); - MDNode *KMD = Metadata[i].second; + MDNode *KMD = MD.second; switch (Kind) { default: @@ -1646,6 +1744,17 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, K->setMetadata(LLVMContext::MD_invariant_group, JMD); } +void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J) { + unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_range, + LLVMContext::MD_invariant_load, LLVMContext::MD_nonnull, + LLVMContext::MD_invariant_group, LLVMContext::MD_align, + LLVMContext::MD_dereferenceable, + LLVMContext::MD_dereferenceable_or_null}; + combineMetadata(K, J, KnownIDs); +} + unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, DominatorTree &DT, const BasicBlockEdge &Root) { @@ -1703,6 +1812,7 @@ bool llvm::callsGCLeafFunction(ImmutableCallSite CS) { return false; } +namespace { /// A potential constituent of a bitreverse or bswap expression. See /// collectBitParts for a fuller explanation. struct BitPart { @@ -1718,6 +1828,7 @@ struct BitPart { enum { Unset = -1 }; }; +} // end anonymous namespace /// Analyze the specified subexpression and see if it is capable of providing /// pieces of a bswap or bitreverse. The subexpression provides a potential @@ -1954,23 +2065,12 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( // in ASan/MSan/TSan/DFSan, and thus make us miss some memory accesses, // we mark affected calls as NoBuiltin, which will disable optimization // in CodeGen. 
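// A typical use of combineMetadataForCSE (sketch; callers are CSE/GVN-style
// transforms): when load J is about to be folded into an equivalent earlier
// load K, metadata that is only valid for one of the two must be intersected
// before K replaces J:
//
//   combineMetadataForCSE(K, J);   // intersect !tbaa, !range, !nonnull, ...
//   J->replaceAllUsesWith(K);
//   J->eraseFromParent();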
-void llvm::maybeMarkSanitizerLibraryCallNoBuiltin(CallInst *CI, - const TargetLibraryInfo *TLI) { +void llvm::maybeMarkSanitizerLibraryCallNoBuiltin( + CallInst *CI, const TargetLibraryInfo *TLI) { Function *F = CI->getCalledFunction(); LibFunc::Func Func; - if (!F || F->hasLocalLinkage() || !F->hasName() || - !TLI->getLibFunc(F->getName(), Func)) - return; - switch (Func) { - default: break; - case LibFunc::memcmp: - case LibFunc::memchr: - case LibFunc::strcpy: - case LibFunc::stpcpy: - case LibFunc::strcmp: - case LibFunc::strlen: - case LibFunc::strnlen: - CI->addAttribute(AttributeSet::FunctionIndex, Attribute::NoBuiltin); - break; - } + if (F && !F->hasLocalLinkage() && F->hasName() && + TLI->getLibFunc(F->getName(), Func) && TLI->hasOptimizedCodeGen(Func) && + !F->doesNotAccessMemory()) + CI->addAttribute(AttributeSet::FunctionIndex, Attribute::NoBuiltin); } diff --git a/contrib/llvm/lib/Transforms/Utils/LoopSimplify.cpp b/contrib/llvm/lib/Transforms/Utils/LoopSimplify.cpp index 2846e8f..00cda2a 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopSimplify.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopSimplify.cpp @@ -361,25 +361,12 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, // Fix LCSSA form for L. Some values, which previously were only used inside // L, can now be used in NewOuter loop. We need to insert phi-nodes for them // in corresponding exit blocks. + // We don't need to form LCSSA recursively, because there cannot be uses + // inside a newly created loop of defs from inner loops as those would + // already be a use of an LCSSA phi node. + formLCSSA(*L, *DT, LI, SE); - // Go through all instructions in OuterLoopBlocks and check if they are - // using operands from the inner loop. In this case we'll need to fix LCSSA - // for these instructions. - SmallSetVector<Instruction *, 8> WorklistSet; - for (BasicBlock *OuterBB: OuterLoopBlocks) { - for (Instruction &I : *OuterBB) { - for (Value *Op : I.operands()) { - Instruction *OpI = dyn_cast<Instruction>(Op); - if (!OpI || !L->contains(OpI)) - continue; - WorklistSet.insert(OpI); - } - } - } - SmallVector<Instruction *, 8> Worklist(WorklistSet.begin(), - WorklistSet.end()); - formLCSSAForInstructions(Worklist, *DT, *LI); - assert(NewOuter->isRecursivelyLCSSAForm(*DT) && + assert(NewOuter->isRecursivelyLCSSAForm(*DT, *LI) && "LCSSA is broken after separating nested loops!"); } @@ -483,13 +470,21 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, } // Now that all of the PHI nodes have been inserted and adjusted, modify the - // backedge blocks to just to the BEBlock instead of the header. + // backedge blocks to jump to the BEBlock instead of the header. + // If one of the backedges has llvm.loop metadata attached, we remove + // it from the backedge and add it to BEBlock. + unsigned LoopMDKind = BEBlock->getContext().getMDKindID("llvm.loop"); + MDNode *LoopMD = nullptr; for (unsigned i = 0, e = BackedgeBlocks.size(); i != e; ++i) { TerminatorInst *TI = BackedgeBlocks[i]->getTerminator(); + if (!LoopMD) + LoopMD = TI->getMetadata(LoopMDKind); + TI->setMetadata(LoopMDKind, nullptr); for (unsigned Op = 0, e = TI->getNumSuccessors(); Op != e; ++Op) if (TI->getSuccessor(Op) == Header) TI->setSuccessor(Op, BEBlock); } + BEBlock->getTerminator()->setMetadata(LoopMDKind, LoopMD); //===--- Update all analyses which we must preserve now -----------------===// @@ -535,7 +530,7 @@ ReprocessLoop: // Zap the dead pred's terminator and replace it with unreachable. 
TerminatorInst *TI = P->getTerminator(); - changeToUnreachable(TI, /*UseLLVMTrap=*/false); + changeToUnreachable(TI, /*UseLLVMTrap=*/false, PreserveLCSSA); Changed = true; } } @@ -635,8 +630,10 @@ ReprocessLoop: (PN = dyn_cast<PHINode>(I++)); ) if (Value *V = SimplifyInstruction(PN, DL, nullptr, DT, AC)) { if (SE) SE->forgetValue(PN); - PN->replaceAllUsesWith(V); - PN->eraseFromParent(); + if (!PreserveLCSSA || LI->replacementPreservesLCSSAForm(PN, V)) { + PN->replaceAllUsesWith(V); + PN->eraseFromParent(); + } } // If this loop has multiple exits and the exits all go to the same @@ -821,8 +818,8 @@ bool LoopSimplify::runOnFunction(Function &F) { if (PreserveLCSSA) { assert(DT && "DT not available."); assert(LI && "LI not available."); - bool InLCSSA = - all_of(*LI, [&](Loop *L) { return L->isRecursivelyLCSSAForm(*DT); }); + bool InLCSSA = all_of( + *LI, [&](Loop *L) { return L->isRecursivelyLCSSAForm(*DT, *LI); }); assert(InLCSSA && "Requested to preserve LCSSA, but it's already broken."); } #endif @@ -833,8 +830,8 @@ bool LoopSimplify::runOnFunction(Function &F) { #ifndef NDEBUG if (PreserveLCSSA) { - bool InLCSSA = - all_of(*LI, [&](Loop *L) { return L->isRecursivelyLCSSAForm(*DT); }); + bool InLCSSA = all_of( + *LI, [&](Loop *L) { return L->isRecursivelyLCSSAForm(*DT, *LI); }); assert(InLCSSA && "LCSSA is broken after loop-simplify."); } #endif @@ -842,7 +839,7 @@ bool LoopSimplify::runOnFunction(Function &F) { } PreservedAnalyses LoopSimplifyPass::run(Function &F, - AnalysisManager<Function> &AM) { + FunctionAnalysisManager &AM) { bool Changed = false; LoopInfo *LI = &AM.getResult<LoopAnalysis>(F); DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); @@ -854,6 +851,10 @@ PreservedAnalyses LoopSimplifyPass::run(Function &F, for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I) Changed |= simplifyLoop(*I, DT, LI, SE, AC, true /* PreserveLCSSA */); + // FIXME: We need to invalidate this to avoid PR28400. Is there a better + // solution? 
+ AM.invalidate<ScalarEvolutionAnalysis>(F); + if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/contrib/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/contrib/llvm/lib/Transforms/Utils/LoopUnroll.cpp index 7f1f78f..e346ebd 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -23,11 +23,12 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -46,7 +47,7 @@ STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled"); STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)"); static cl::opt<bool> -UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(true), cl::Hidden, +UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(false), cl::Hidden, cl::desc("Allow runtime unrolled loops to be unrolled " "with epilog instead of prolog.")); @@ -171,20 +172,58 @@ static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks, return false; } +/// Adds ClonedBB to LoopInfo, creates a new loop for ClonedBB if necessary +/// and adds a mapping from the original loop to the new loop to NewLoops. +/// Returns nullptr if no new loop was created and a pointer to the +/// original loop OriginalBB was part of otherwise. +const Loop* llvm::addClonedBlockToLoopInfo(BasicBlock *OriginalBB, + BasicBlock *ClonedBB, LoopInfo *LI, + NewLoopsMap &NewLoops) { + // Figure out which loop New is in. + const Loop *OldLoop = LI->getLoopFor(OriginalBB); + assert(OldLoop && "Should (at least) be in the loop being unrolled!"); + + Loop *&NewLoop = NewLoops[OldLoop]; + if (!NewLoop) { + // Found a new sub-loop. + assert(OriginalBB == OldLoop->getHeader() && + "Header should be first in RPO"); + + NewLoop = new Loop(); + Loop *NewLoopParent = NewLoops.lookup(OldLoop->getParentLoop()); + + if (NewLoopParent) + NewLoopParent->addChildLoop(NewLoop); + else + LI->addTopLevelLoop(NewLoop); + + NewLoop->addBasicBlockToLoop(ClonedBB, *LI); + return OldLoop; + } else { + NewLoop->addBasicBlockToLoop(ClonedBB, *LI); + return nullptr; + } +} + /// Unroll the given loop by Count. The loop must be in LCSSA form. Returns true /// if unrolling was successful, or false if the loop was unmodified. Unrolling /// can only fail when the loop's latch block is not terminated by a conditional /// branch instruction. However, if the trip count (and multiple) are not known, /// loop unrolling will mostly produce more code that is no faster. /// -/// TripCount is generally defined as the number of times the loop header -/// executes. UnrollLoop relaxes the definition to permit early exits: here -/// TripCount is the iteration on which control exits LatchBlock if no early -/// exits were taken. Note that UnrollLoop assumes that the loop counter test -/// terminates LatchBlock in order to remove unnecesssary instances of the -/// test. In other words, control may exit the loop prior to TripCount -/// iterations via an early branch, but control may not exit the loop from the -/// LatchBlock's terminator prior to TripCount iterations. 
+/// TripCount is the upper bound of the iteration on which control exits +/// LatchBlock. Control may exit the loop prior to TripCount iterations either +/// via an early branch in other loop block or via LatchBlock terminator. This +/// is relaxed from the general definition of trip count which is the number of +/// times the loop header executes. Note that UnrollLoop assumes that the loop +/// counter test is in LatchBlock in order to remove unnecesssary instances of +/// the test. If control can exit the loop from the LatchBlock's terminator +/// prior to TripCount iterations, flag PreserveCondBr needs to be set. +/// +/// PreserveCondBr indicates whether the conditional branch of the LatchBlock +/// needs to be preserved. It is needed when we use trip count upper bound to +/// fully unroll the loop. If PreserveOnlyFirst is also set then only the first +/// conditional branch needs to be preserved. /// /// Similarly, TripMultiple divides the number of times that the LatchBlock may /// execute without exiting the loop. @@ -196,15 +235,21 @@ static bool needToInsertPhisForLCSSA(Loop *L, std::vector<BasicBlock *> Blocks, /// runtime-unroll the loop if computing RuntimeTripCount will be expensive and /// AllowExpensiveTripCount is false. /// +/// If we want to perform PGO-based loop peeling, PeelCount is set to the +/// number of iterations we want to peel off. +/// /// The LoopInfo Analysis that is passed will be kept consistent. /// /// This utility preserves LoopInfo. It will also preserve ScalarEvolution and /// DominatorTree if they are non-null. bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, bool AllowRuntime, bool AllowExpensiveTripCount, - unsigned TripMultiple, LoopInfo *LI, ScalarEvolution *SE, - DominatorTree *DT, AssumptionCache *AC, + bool PreserveCondBr, bool PreserveOnlyFirst, + unsigned TripMultiple, unsigned PeelCount, LoopInfo *LI, + ScalarEvolution *SE, DominatorTree *DT, + AssumptionCache *AC, OptimizationRemarkEmitter *ORE, bool PreserveLCSSA) { + BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n"); @@ -250,9 +295,8 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, if (TripCount != 0 && Count > TripCount) Count = TripCount; - // Don't enter the unroll code if there is nothing to do. This way we don't - // need to support "partial unrolling by 1". - if (TripCount == 0 && Count < 2) + // Don't enter the unroll code if there is nothing to do. + if (TripCount == 0 && Count < 2 && PeelCount == 0) return false; assert(Count > 0); @@ -272,14 +316,22 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, // now we just recompute LCSSA for the outer loop, but it should be possible // to fix it in-place. bool NeedToFixLCSSA = PreserveLCSSA && CompletelyUnroll && - std::any_of(ExitBlocks.begin(), ExitBlocks.end(), - [&](BasicBlock *BB) { return isa<PHINode>(BB->begin()); }); + any_of(ExitBlocks, [](const BasicBlock *BB) { + return isa<PHINode>(BB->begin()); + }); // We assume a run-time trip count if the compiler cannot // figure out the loop trip count and the unroll-runtime // flag is specified. 
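// With the extended parameter list, a caller that fully unrolls by an exact
// trip count and does no peeling would invoke this roughly as (argument
// values are illustrative):
//
//   UnrollLoop(L, /*Count*/ TripCount, TripCount, /*Force*/ false,
//              /*AllowRuntime*/ false, /*AllowExpensiveTripCount*/ false,
//              /*PreserveCondBr*/ false, /*PreserveOnlyFirst*/ false,
//              TripMultiple, /*PeelCount*/ 0, LI, SE, DT, AC, &ORE,
//              PreserveLCSSA);
//
// whereas unrolling by a trip-count upper bound passes PreserveCondBr=true so
// that every unrolled iteration except the last keeps its exit test.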
bool RuntimeTripCount = (TripCount == 0 && Count > 0 && AllowRuntime); + assert((!RuntimeTripCount || !PeelCount) && + "Did not expect runtime trip-count unrolling " + "and peeling for the same loop"); + + if (PeelCount) + peelLoop(L, PeelCount, LI, SE, DT, PreserveLCSSA); + // Loops containing convergent instructions must have a count that divides // their TripMultiple. DEBUG( @@ -293,9 +345,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, "Unroll count must divide trip multiple if loop contains a " "convergent operation."); }); - // Don't output the runtime loop remainder if Count is a multiple of - // TripMultiple. Such a remainder is never needed, and is unsafe if the loop - // contains a convergent instruction. + if (RuntimeTripCount && TripMultiple % Count != 0 && !UnrollRuntimeLoopRemainder(L, Count, AllowExpensiveTripCount, UnrollRuntimeEpilog, LI, SE, DT, @@ -322,35 +372,40 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, (unsigned)GreatestCommonDivisor64(Count, TripMultiple); } + using namespace ore; // Report the unrolling decision. - DebugLoc LoopLoc = L->getStartLoc(); - Function *F = Header->getParent(); - LLVMContext &Ctx = F->getContext(); - if (CompletelyUnroll) { DEBUG(dbgs() << "COMPLETELY UNROLLING loop %" << Header->getName() << " with trip count " << TripCount << "!\n"); - emitOptimizationRemark(Ctx, DEBUG_TYPE, *F, LoopLoc, - Twine("completely unrolled loop with ") + - Twine(TripCount) + " iterations"); + ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(), + L->getHeader()) + << "completely unrolled loop with " + << NV("UnrollCount", TripCount) << " iterations"); + } else if (PeelCount) { + DEBUG(dbgs() << "PEELING loop %" << Header->getName() + << " with iteration count " << PeelCount << "!\n"); + ORE->emit(OptimizationRemark(DEBUG_TYPE, "Peeled", L->getStartLoc(), + L->getHeader()) + << " peeled loop by " << NV("PeelCount", PeelCount) + << " iterations"); } else { - auto EmitDiag = [&](const Twine &T) { - emitOptimizationRemark(Ctx, DEBUG_TYPE, *F, LoopLoc, - "unrolled loop by a factor of " + Twine(Count) + - T); - }; + OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(), + L->getHeader()); + Diag << "unrolled loop by a factor of " << NV("UnrollCount", Count); DEBUG(dbgs() << "UNROLLING loop %" << Header->getName() << " by " << Count); if (TripMultiple == 0 || BreakoutTrip != TripMultiple) { DEBUG(dbgs() << " with a breakout at trip " << BreakoutTrip); - EmitDiag(" with a breakout at trip " + Twine(BreakoutTrip)); + ORE->emit(Diag << " with a breakout at trip " + << NV("BreakoutTrip", BreakoutTrip)); } else if (TripMultiple != 1) { DEBUG(dbgs() << " with " << TripMultiple << " trips per branch"); - EmitDiag(" with " + Twine(TripMultiple) + " trips per branch"); + ORE->emit(Diag << " with " << NV("TripMultiple", TripMultiple) + << " trips per branch"); } else if (RuntimeTripCount) { DEBUG(dbgs() << " with run-time trip count"); - EmitDiag(" with run-time trip count"); + ORE->emit(Diag << " with run-time trip count"); } DEBUG(dbgs() << "!\n"); } @@ -382,6 +437,15 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO(); std::vector<BasicBlock*> UnrolledLoopBlocks = L->getBlocks(); + + // Loop Unrolling might create new loops. While we do preserve LoopInfo, we + // might break loop-simplified form for these loops (as they, e.g., would + // share the same exit blocks). 
We'll keep track of loops for which we can + // break this so that later we can re-simplify them. + SmallSetVector<Loop *, 4> LoopsToSimplify; + for (Loop *SubLoop : *L) + LoopsToSimplify.insert(SubLoop); + for (unsigned It = 1; It != Count; ++It) { std::vector<BasicBlock*> NewBlocks; SmallDenseMap<const Loop *, Loop *, 4> NewLoops; @@ -397,27 +461,14 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, assert(LI->getLoopFor(*BB) == L && "Header should not be in a sub-loop"); L->addBasicBlockToLoop(New, *LI); } else { - // Figure out which loop New is in. - const Loop *OldLoop = LI->getLoopFor(*BB); - assert(OldLoop && "Should (at least) be in the loop being unrolled!"); - - Loop *&NewLoop = NewLoops[OldLoop]; - if (!NewLoop) { - // Found a new sub-loop. - assert(*BB == OldLoop->getHeader() && - "Header should be first in RPO"); - - Loop *NewLoopParent = NewLoops.lookup(OldLoop->getParentLoop()); - assert(NewLoopParent && - "Expected parent loop before sub-loop in RPO"); - NewLoop = new Loop; - NewLoopParent->addChildLoop(NewLoop); + const Loop *OldLoop = addClonedBlockToLoopInfo(*BB, New, LI, NewLoops); + if (OldLoop) { + LoopsToSimplify.insert(NewLoops[OldLoop]); // Forget the old loop, since its inputs may have changed. if (SE) SE->forgetLoop(OldLoop); } - NewLoop->addBasicBlockToLoop(New, *LI); } if (*BB == Header) @@ -480,9 +531,14 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, } // Remap all instructions in the most recent iteration - for (BasicBlock *NewBlock : NewBlocks) - for (Instruction &I : *NewBlock) + for (BasicBlock *NewBlock : NewBlocks) { + for (Instruction &I : *NewBlock) { ::remapInstruction(&I, LastValueMap); + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC->registerAssumption(II); + } + } } // Loop over the PHI nodes in the original block, setting incoming values. @@ -524,12 +580,16 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, if (CompletelyUnroll) { if (j == 0) Dest = LoopExit; - NeedConditional = false; - } - - // If we know the trip count or a multiple of it, we can safely use an - // unconditional branch for some iterations. - if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) { + // If using trip count upper bound to completely unroll, we need to keep + // the conditional branch except the last one because the loop may exit + // after any iteration. + assert(NeedConditional && + "NeedCondition cannot be modified by both complete " + "unrolling and runtime unrolling"); + NeedConditional = (PreserveCondBr && j && !(PreserveOnlyFirst && i != 0)); + } else if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) { + // If we know the trip count or a multiple of it, we can safely use an + // unconditional branch for some iterations. NeedConditional = false; } @@ -595,10 +655,6 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, } } - // FIXME: We could register any cloned assumptions instead of clearing the - // whole function's cache. - AC->clear(); - // FIXME: We only preserve DT info for complete unrolling now. Incrementally // updating domtree after partial loop unrolling should also be easy. if (DT && !CompletelyUnroll) @@ -607,7 +663,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, DEBUG(DT->verifyDomTree()); // Simplify any new induction variables in the partially unrolled loop. 
- if (SE && !CompletelyUnroll) { + if (SE && !CompletelyUnroll && Count > 1) { SmallVector<WeakVH, 16> DeadInsts; simplifyLoopIVs(L, SE, DT, LI, DeadInsts); @@ -636,6 +692,11 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, } } + // TODO: after peeling or unrolling, previously loop variant conditions are + // likely to fold to constants, eagerly propagating those here will require + // fewer cleanup passes to be run. Alternatively, a LoopEarlyCSE might be + // appropriate. + NumCompletelyUnrolled += CompletelyUnroll; ++NumUnrolled; @@ -663,6 +724,11 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, if (!OuterL && !CompletelyUnroll) OuterL = L; if (OuterL) { + // OuterL includes all loops for which we can break loop-simplify, so + // it's sufficient to simplify only it (it'll recursively simplify inner + // loops too). + // TODO: That potentially might be compile-time expensive. We should try + // to fix the loop-simplified form incrementally. simplifyLoop(OuterL, DT, LI, SE, AC, PreserveLCSSA); // LCSSA must be performed on the outermost affected loop. The unrolled @@ -678,6 +744,10 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, else assert(OuterL->isLCSSAForm(*DT) && "Loops should be in LCSSA form after loop-unroll."); + } else { + // Simplify loops for which we might've broken loop-simplify form. + for (Loop *SubLoop : LoopsToSimplify) + simplifyLoop(SubLoop, DT, LI, SE, AC, PreserveLCSSA); } } diff --git a/contrib/llvm/lib/Transforms/Utils/LoopUnrollPeel.cpp b/contrib/llvm/lib/Transforms/Utils/LoopUnrollPeel.cpp new file mode 100644 index 0000000..842cf31 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Utils/LoopUnrollPeel.cpp @@ -0,0 +1,414 @@ +//===-- UnrollLoopPeel.cpp - Loop peeling utilities -----------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements some loop unrolling utilities for peeling loops +// with dynamically inferred (from PGO) trip counts. See LoopUnroll.cpp for +// unrolling loops with compile-time constant trip counts. 
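// In source terms, peeling one iteration of a bottom-tested loop rewrites,
// roughly:
//
//   do {                             body();                // peeled copy
//     body();                        if (!cond) goto exit;
//   } while (cond);          =>      do {
//                                      body();              // remaining loop
//                                    } while (cond);
//                                  exit:
//
// so the peeled copies, which carry most of the dynamic execution for
// short-running loops, can be simplified independently of the loop body.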
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/UnrollLoop.h" +#include <algorithm> + +using namespace llvm; + +#define DEBUG_TYPE "loop-unroll" +STATISTIC(NumPeeled, "Number of loops peeled"); + +static cl::opt<unsigned> UnrollPeelMaxCount( + "unroll-peel-max-count", cl::init(7), cl::Hidden, + cl::desc("Max average trip count which will cause loop peeling.")); + +static cl::opt<unsigned> UnrollForcePeelCount( + "unroll-force-peel-count", cl::init(0), cl::Hidden, + cl::desc("Force a peel count regardless of profiling information.")); + +// Check whether we are capable of peeling this loop. +static bool canPeel(Loop *L) { + // Make sure the loop is in simplified form + if (!L->isLoopSimplifyForm()) + return false; + + // Only peel loops that contain a single exit + if (!L->getExitingBlock() || !L->getUniqueExitBlock()) + return false; + + return true; +} + +// Return the number of iterations we want to peel off. +void llvm::computePeelCount(Loop *L, unsigned LoopSize, + TargetTransformInfo::UnrollingPreferences &UP) { + UP.PeelCount = 0; + if (!canPeel(L)) + return; + + // Only try to peel innermost loops. + if (!L->empty()) + return; + + // If the user provided a peel count, use that. + bool UserPeelCount = UnrollForcePeelCount.getNumOccurrences() > 0; + if (UserPeelCount) { + DEBUG(dbgs() << "Force-peeling first " << UnrollForcePeelCount + << " iterations.\n"); + UP.PeelCount = UnrollForcePeelCount; + return; + } + + // If we don't know the trip count, but have reason to believe the average + // trip count is low, peeling should be beneficial, since we will usually + // hit the peeled section. + // We only do this in the presence of profile information, since otherwise + // our estimates of the trip count are not reliable enough. + if (UP.AllowPeeling && L->getHeader()->getParent()->getEntryCount()) { + Optional<unsigned> PeelCount = getLoopEstimatedTripCount(L); + if (!PeelCount) + return; + + DEBUG(dbgs() << "Profile-based estimated trip count is " << *PeelCount + << "\n"); + + if (*PeelCount) { + if ((*PeelCount <= UnrollPeelMaxCount) && + (LoopSize * (*PeelCount + 1) <= UP.Threshold)) { + DEBUG(dbgs() << "Peeling first " << *PeelCount << " iterations.\n"); + UP.PeelCount = *PeelCount; + return; + } + DEBUG(dbgs() << "Requested peel count: " << *PeelCount << "\n"); + DEBUG(dbgs() << "Max peel count: " << UnrollPeelMaxCount << "\n"); + DEBUG(dbgs() << "Peel cost: " << LoopSize * (*PeelCount + 1) << "\n"); + DEBUG(dbgs() << "Max peel cost: " << UP.Threshold << "\n"); + } + } + + return; +} + +/// \brief Update the branch weights of the latch of a peeled-off loop +/// iteration. +/// This sets the branch weights for the latch of the recently peeled off loop +/// iteration correctly. +/// Our goal is to make sure that: +/// a) The total weight of all the copies of the loop body is preserved. 
+/// b) The total weight of the loop exit is preserved. +/// c) The body weight is reasonably distributed between the peeled iterations. +/// +/// \param Header The copy of the header block that belongs to next iteration. +/// \param LatchBR The copy of the latch branch that belongs to this iteration. +/// \param IterNumber The serial number of the iteration that was just +/// peeled off. +/// \param AvgIters The average number of iterations we expect the loop to have. +/// \param[in,out] PeeledHeaderWeight The total number of dynamic loop +/// iterations that are unaccounted for. As an input, it represents the number +/// of times we expect to enter the header of the iteration currently being +/// peeled off. The output is the number of times we expect to enter the +/// header of the next iteration. +static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR, + unsigned IterNumber, unsigned AvgIters, + uint64_t &PeeledHeaderWeight) { + + // FIXME: Pick a more realistic distribution. + // Currently the proportion of weight we assign to the fall-through + // side of the branch drops linearly with the iteration number, and we use + // a 0.9 fudge factor to make the drop-off less sharp... + if (PeeledHeaderWeight) { + uint64_t FallThruWeight = + PeeledHeaderWeight * ((float)(AvgIters - IterNumber) / AvgIters * 0.9); + uint64_t ExitWeight = PeeledHeaderWeight - FallThruWeight; + PeeledHeaderWeight -= ExitWeight; + + unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); + MDBuilder MDB(LatchBR->getContext()); + MDNode *WeightNode = + HeaderIdx ? MDB.createBranchWeights(ExitWeight, FallThruWeight) + : MDB.createBranchWeights(FallThruWeight, ExitWeight); + LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); + } +} + +/// \brief Clones the body of the loop L, putting it between \p InsertTop and \p +/// InsertBot. +/// \param IterNumber The serial number of the iteration currently being +/// peeled off. +/// \param Exit The exit block of the original loop. +/// \param[out] NewBlocks A list of the the blocks in the newly created clone +/// \param[out] VMap The value map between the loop and the new clone. +/// \param LoopBlocks A helper for DFS-traversal of the loop. +/// \param LVMap A value-map that maps instructions from the original loop to +/// instructions in the last peeled-off iteration. +static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop, + BasicBlock *InsertBot, BasicBlock *Exit, + SmallVectorImpl<BasicBlock *> &NewBlocks, + LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, + ValueToValueMapTy &LVMap, LoopInfo *LI) { + + BasicBlock *Header = L->getHeader(); + BasicBlock *Latch = L->getLoopLatch(); + BasicBlock *PreHeader = L->getLoopPreheader(); + + Function *F = Header->getParent(); + LoopBlocksDFS::RPOIterator BlockBegin = LoopBlocks.beginRPO(); + LoopBlocksDFS::RPOIterator BlockEnd = LoopBlocks.endRPO(); + Loop *ParentLoop = L->getParentLoop(); + + // For each block in the original loop, create a new copy, + // and update the value map with the newly created values. + for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { + BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, ".peel", F); + NewBlocks.push_back(NewBB); + + if (ParentLoop) + ParentLoop->addBasicBlockToLoop(NewBB, *LI); + + VMap[*BB] = NewBB; + } + + // Hook-up the control flow for the newly inserted blocks. 
+ // The new header is hooked up directly to the "top", which is either + // the original loop preheader (for the first iteration) or the previous + // iteration's exiting block (for every other iteration) + InsertTop->getTerminator()->setSuccessor(0, cast<BasicBlock>(VMap[Header])); + + // Similarly, for the latch: + // The original exiting edge is still hooked up to the loop exit. + // The backedge now goes to the "bottom", which is either the loop's real + // header (for the last peeled iteration) or the copied header of the next + // iteration (for every other iteration) + BranchInst *LatchBR = + cast<BranchInst>(cast<BasicBlock>(VMap[Latch])->getTerminator()); + unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); + LatchBR->setSuccessor(HeaderIdx, InsertBot); + LatchBR->setSuccessor(1 - HeaderIdx, Exit); + + // The new copy of the loop body starts with a bunch of PHI nodes + // that pick an incoming value from either the preheader, or the previous + // loop iteration. Since this copy is no longer part of the loop, we + // resolve this statically: + // For the first iteration, we use the value from the preheader directly. + // For any other iteration, we replace the phi with the value generated by + // the immediately preceding clone of the loop body (which represents + // the previous iteration). + for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { + PHINode *NewPHI = cast<PHINode>(VMap[&*I]); + if (IterNumber == 0) { + VMap[&*I] = NewPHI->getIncomingValueForBlock(PreHeader); + } else { + Value *LatchVal = NewPHI->getIncomingValueForBlock(Latch); + Instruction *LatchInst = dyn_cast<Instruction>(LatchVal); + if (LatchInst && L->contains(LatchInst)) + VMap[&*I] = LVMap[LatchInst]; + else + VMap[&*I] = LatchVal; + } + cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI); + } + + // Fix up the outgoing values - we need to add a value for the iteration + // we've just created. Note that this must happen *after* the incoming + // values are adjusted, since the value going out of the latch may also be + // a value coming into the header. + for (BasicBlock::iterator I = Exit->begin(); isa<PHINode>(I); ++I) { + PHINode *PHI = cast<PHINode>(I); + Value *LatchVal = PHI->getIncomingValueForBlock(Latch); + Instruction *LatchInst = dyn_cast<Instruction>(LatchVal); + if (LatchInst && L->contains(LatchInst)) + LatchVal = VMap[LatchVal]; + PHI->addIncoming(LatchVal, cast<BasicBlock>(VMap[Latch])); + } + + // LastValueMap is updated with the values for the current loop + // which are used the next time this function is called. + for (const auto &KV : VMap) + LVMap[KV.first] = KV.second; +} + +/// \brief Peel off the first \p PeelCount iterations of loop \p L. +/// +/// Note that this does not peel them off as a single straight-line block. +/// Rather, each iteration is peeled off separately, and needs to check the +/// exit condition. +/// For loops that dynamically execute \p PeelCount iterations or less +/// this provides a benefit, since the peeled off iterations, which account +/// for the bulk of dynamic execution, can be further simplified by scalar +/// optimizations. 
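// Illustrative sketch (plain C++, hypothetical body(), n > 0 assumed; not from
// this patch): the source-level shape of peeling one iteration of a
// bottom-tested loop, which is the shape drawn in the comment block inside
// peelLoop() below. Each peeled copy keeps its own exit test, as the paragraph
// above stresses.
#include <cstdio>

static void body(int i) { std::printf("iteration %d\n", i); }

static void original(int n) {
  int i = 0;
  do {
    body(i);
  } while (++i != n);
}

static void peeledOnce(int n) {
  int i = 0;
  body(i);            // peeled copy of iteration 0 ...
  if (++i == n)       // ... followed by its own exit test
    return;
  do {                // the remaining "real" loop, entered via a new preheader
    body(i);
  } while (++i != n);
}

int main() {
  original(3);    // 0, 1, 2
  peeledOnce(3);  // same trace; only the control flow changed
}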
+bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, + ScalarEvolution *SE, DominatorTree *DT, + bool PreserveLCSSA) { + if (!canPeel(L)) + return false; + + LoopBlocksDFS LoopBlocks(L); + LoopBlocks.perform(LI); + + BasicBlock *Header = L->getHeader(); + BasicBlock *PreHeader = L->getLoopPreheader(); + BasicBlock *Latch = L->getLoopLatch(); + BasicBlock *Exit = L->getUniqueExitBlock(); + + Function *F = Header->getParent(); + + // Set up all the necessary basic blocks. It is convenient to split the + // preheader into 3 parts - two blocks to anchor the peeled copy of the loop + // body, and a new preheader for the "real" loop. + + // Peeling the first iteration transforms. + // + // PreHeader: + // ... + // Header: + // LoopBody + // If (cond) goto Header + // Exit: + // + // into + // + // InsertTop: + // LoopBody + // If (!cond) goto Exit + // InsertBot: + // NewPreHeader: + // ... + // Header: + // LoopBody + // If (cond) goto Header + // Exit: + // + // Each following iteration will split the current bottom anchor in two, + // and put the new copy of the loop body between these two blocks. That is, + // after peeling another iteration from the example above, we'll split + // InsertBot, and get: + // + // InsertTop: + // LoopBody + // If (!cond) goto Exit + // InsertBot: + // LoopBody + // If (!cond) goto Exit + // InsertBot.next: + // NewPreHeader: + // ... + // Header: + // LoopBody + // If (cond) goto Header + // Exit: + + BasicBlock *InsertTop = SplitEdge(PreHeader, Header, DT, LI); + BasicBlock *InsertBot = + SplitBlock(InsertTop, InsertTop->getTerminator(), DT, LI); + BasicBlock *NewPreHeader = + SplitBlock(InsertBot, InsertBot->getTerminator(), DT, LI); + + InsertTop->setName(Header->getName() + ".peel.begin"); + InsertBot->setName(Header->getName() + ".peel.next"); + NewPreHeader->setName(PreHeader->getName() + ".peel.newph"); + + ValueToValueMapTy LVMap; + + // If we have branch weight information, we'll want to update it for the + // newly created branches. + BranchInst *LatchBR = + cast<BranchInst>(cast<BasicBlock>(Latch)->getTerminator()); + unsigned HeaderIdx = (LatchBR->getSuccessor(0) == Header ? 0 : 1); + + uint64_t TrueWeight, FalseWeight; + uint64_t ExitWeight = 0, CurHeaderWeight = 0; + if (LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) { + ExitWeight = HeaderIdx ? TrueWeight : FalseWeight; + // The # of times the loop body executes is the sum of the exit block + // weight and the # of times the backedges are taken. + CurHeaderWeight = TrueWeight + FalseWeight; + } + + // For each peeled-off iteration, make a copy of the loop. + for (unsigned Iter = 0; Iter < PeelCount; ++Iter) { + SmallVector<BasicBlock *, 8> NewBlocks; + ValueToValueMapTy VMap; + + // Subtract the exit weight from the current header weight -- the exit + // weight is exactly the weight of the previous iteration's header. + // FIXME: due to the way the distribution is constructed, we need a + // guard here to make sure we don't end up with non-positive weights. 
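// A minimal standalone sketch (hypothetical weights, not from this patch) of
// the per-iteration split performed by updateBranchWeights() above: the
// fall-through share shrinks linearly with the iteration number, damped by the
// 0.9 fudge factor, and whatever is not carried forward becomes that
// iteration's exit weight.
#include <cstdint>
#include <cstdio>

int main() {
  uint64_t PeeledHeaderWeight = 1000; // times we expect to enter iteration 0
  const unsigned AvgIters = 4;        // estimated trip count
  for (unsigned Iter = 0; Iter < 3 && PeeledHeaderWeight; ++Iter) {
    uint64_t FallThru =
        PeeledHeaderWeight * ((float)(AvgIters - Iter) / AvgIters * 0.9);
    uint64_t Exit = PeeledHeaderWeight - FallThru;
    std::printf("iter %u: fall-through %llu, exit %llu\n", Iter,
                (unsigned long long)FallThru, (unsigned long long)Exit);
    PeeledHeaderWeight -= Exit; // what remains enters the next peeled copy
  }
  // Prints roughly 900/100, 607/293, 273/334 -- the same shape the FIXME in
  // updateBranchWeights() admits is only a crude distribution.
}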
+ if (ExitWeight < CurHeaderWeight) + CurHeaderWeight -= ExitWeight; + else + CurHeaderWeight = 1; + + cloneLoopBlocks(L, Iter, InsertTop, InsertBot, Exit, + NewBlocks, LoopBlocks, VMap, LVMap, LI); + updateBranchWeights(InsertBot, cast<BranchInst>(VMap[LatchBR]), Iter, + PeelCount, ExitWeight); + + InsertTop = InsertBot; + InsertBot = SplitBlock(InsertBot, InsertBot->getTerminator(), DT, LI); + InsertBot->setName(Header->getName() + ".peel.next"); + + F->getBasicBlockList().splice(InsertTop->getIterator(), + F->getBasicBlockList(), + NewBlocks[0]->getIterator(), F->end()); + + // Remap to use values from the current iteration instead of the + // previous one. + remapInstructionsInBlocks(NewBlocks, VMap); + } + + // Now adjust the phi nodes in the loop header to get their initial values + // from the last peeled-off iteration instead of the preheader. + for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { + PHINode *PHI = cast<PHINode>(I); + Value *NewVal = PHI->getIncomingValueForBlock(Latch); + Instruction *LatchInst = dyn_cast<Instruction>(NewVal); + if (LatchInst && L->contains(LatchInst)) + NewVal = LVMap[LatchInst]; + + PHI->setIncomingValue(PHI->getBasicBlockIndex(NewPreHeader), NewVal); + } + + // Adjust the branch weights on the loop exit. + if (ExitWeight) { + // The backedge count is the difference of current header weight and + // current loop exit weight. If the current header weight is smaller than + // the current loop exit weight, we mark the loop backedge weight as 1. + uint64_t BackEdgeWeight = 0; + if (ExitWeight < CurHeaderWeight) + BackEdgeWeight = CurHeaderWeight - ExitWeight; + else + BackEdgeWeight = 1; + MDBuilder MDB(LatchBR->getContext()); + MDNode *WeightNode = + HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight) + : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); + LatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); + } + + // If the loop is nested, we changed the parent loop, update SE. + if (Loop *ParentLoop = L->getParentLoop()) + SE->forgetLoop(ParentLoop); + + NumPeeled++; + + return true; +} diff --git a/contrib/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/contrib/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 861a50c..d3ea156 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -112,6 +112,18 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, } } + // Make sure that created prolog loop is in simplified form + SmallVector<BasicBlock *, 4> PrologExitPreds; + Loop *PrologLoop = LI->getLoopFor(PrologLatch); + if (PrologLoop) { + for (BasicBlock *PredBB : predecessors(PrologExit)) + if (PrologLoop->contains(PredBB)) + PrologExitPreds.push_back(PredBB); + + SplitBlockPredecessors(PrologExit, PrologExitPreds, ".unr-lcssa", DT, LI, + PreserveLCSSA); + } + // Create a branch around the original loop, which is taken if there are no // iterations remaining to be executed after running the prologue. Instruction *InsertPt = PrologExit->getTerminator(); @@ -289,16 +301,23 @@ static void CloneLoopBlocks(Loop *L, Value *NewIter, LI->addTopLevelLoop(NewLoop); } + NewLoopsMap NewLoops; + if (NewLoop) + NewLoops[L] = NewLoop; + else if (ParentLoop) + NewLoops[L] = ParentLoop; + // For each block in the original loop, create a new copy, // and update the value map with the newly created values. for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) { BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, "." 
+ suffix, F); NewBlocks.push_back(NewBB); - - if (NewLoop) - NewLoop->addBasicBlockToLoop(NewBB, *LI); - else if (ParentLoop) - ParentLoop->addBasicBlockToLoop(NewBB, *LI); + + // If we're unrolling the outermost loop, there's no remainder loop, + // and this block isn't in a nested loop, then the new block is not + // in any loop. Otherwise, add it to loopinfo. + if (CreateRemainderLoop || LI->getLoopFor(*BB) != L || ParentLoop) + addClonedBlockToLoopInfo(*BB, NewBB, LI, NewLoops); VMap[*BB] = NewBB; if (Header == *BB) { @@ -479,11 +498,6 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, if (Log2_32(Count) > BEWidth) return false; - // If this loop is nested, then the loop unroller changes the code in the - // parent loop, so the Scalar Evolution pass needs to be run again. - if (Loop *ParentLoop = L->getParentLoop()) - SE->forgetLoop(ParentLoop); - BasicBlock *Latch = L->getLoopLatch(); // Loop structure is the following: @@ -673,6 +687,12 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, ConnectProlog(L, BECount, Count, PrologExit, PreHeader, NewPreHeader, VMap, DT, LI, PreserveLCSSA); } + + // If this loop is nested, then the loop unroller changes the code in the + // parent loop, so the Scalar Evolution pass needs to be run again. + if (Loop *ParentLoop = L->getParentLoop()) + SE->forgetLoop(ParentLoop); + NumRuntimeUnrolled++; return true; } diff --git a/contrib/llvm/lib/Transforms/Utils/LoopUtils.cpp b/contrib/llvm/lib/Transforms/Utils/LoopUtils.cpp index 3902c67..c8efa9e 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -11,14 +11,17 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -26,7 +29,6 @@ #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" -#include "llvm/Transforms/Utils/LoopUtils.h" using namespace llvm; using namespace llvm::PatternMatch; @@ -305,7 +307,7 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind, // The instruction used by an outside user must be the last instruction // before we feed back to the reduction phi. Otherwise, we loose VF-1 // operations on the value. 
- if (std::find(Phi->op_begin(), Phi->op_end(), Cur) == Phi->op_end()) + if (!is_contained(Phi->operands(), Cur)) return false; ExitInstruction = Cur; @@ -654,8 +656,8 @@ Value *RecurrenceDescriptor::createMinMaxOp(IRBuilder<> &Builder, } InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, - const SCEV *Step) - : StartValue(Start), IK(K), Step(Step) { + const SCEV *Step, BinaryOperator *BOp) + : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) { assert(IK != IK_NoInduction && "Not an induction"); // Start value type should match the induction kind and the value @@ -672,7 +674,15 @@ InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, assert((IK != IK_PtrInduction || getConstIntStepValue()) && "Step value should be constant for pointer induction"); - assert(Step->getType()->isIntegerTy() && "StepValue is not an integer"); + assert((IK == IK_FpInduction || Step->getType()->isIntegerTy()) && + "StepValue is not an integer"); + + assert((IK != IK_FpInduction || Step->getType()->isFloatingPointTy()) && + "StepValue is not FP for FpInduction"); + assert((IK != IK_FpInduction || (InductionBinOp && + (InductionBinOp->getOpcode() == Instruction::FAdd || + InductionBinOp->getOpcode() == Instruction::FSub))) && + "Binary opcode should be specified for FP induction"); } int InductionDescriptor::getConsecutiveDirection() const { @@ -693,6 +703,8 @@ Value *InductionDescriptor::transform(IRBuilder<> &B, Value *Index, const DataLayout& DL) const { SCEVExpander Exp(*SE, DL, "induction"); + assert(Index->getType() == Step->getType() && + "Index type does not match StepValue type"); switch (IK) { case IK_IntInduction: { assert(Index->getType() == StartValue->getType() && @@ -717,29 +729,113 @@ Value *InductionDescriptor::transform(IRBuilder<> &B, Value *Index, return Exp.expandCodeFor(S, StartValue->getType(), &*B.GetInsertPoint()); } case IK_PtrInduction: { - assert(Index->getType() == Step->getType() && - "Index type does not match StepValue type"); assert(isa<SCEVConstant>(Step) && "Expected constant step for pointer induction"); const SCEV *S = SE->getMulExpr(SE->getSCEV(Index), Step); Index = Exp.expandCodeFor(S, Index->getType(), &*B.GetInsertPoint()); return B.CreateGEP(nullptr, StartValue, Index); } + case IK_FpInduction: { + assert(Step->getType()->isFloatingPointTy() && "Expected FP Step value"); + assert(InductionBinOp && + (InductionBinOp->getOpcode() == Instruction::FAdd || + InductionBinOp->getOpcode() == Instruction::FSub) && + "Original bin op should be defined for FP induction"); + + Value *StepValue = cast<SCEVUnknown>(Step)->getValue(); + + // Floating point operations had to be 'fast' to enable the induction. + FastMathFlags Flags; + Flags.setUnsafeAlgebra(); + + Value *MulExp = B.CreateFMul(StepValue, Index); + if (isa<Instruction>(MulExp)) + // We have to check, the MulExp may be a constant. + cast<Instruction>(MulExp)->setFastMathFlags(Flags); + + Value *BOp = B.CreateBinOp(InductionBinOp->getOpcode() , StartValue, + MulExp, "induction"); + if (isa<Instruction>(BOp)) + cast<Instruction>(BOp)->setFastMathFlags(Flags); + + return BOp; + } case IK_NoInduction: return nullptr; } llvm_unreachable("invalid enum"); } -bool InductionDescriptor::isInductionPHI(PHINode *Phi, +bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop, + ScalarEvolution *SE, + InductionDescriptor &D) { + + // Here we only handle FP induction variables. 
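// Illustrative sketch (plain C++, hypothetical names, not from this patch):
// the kind of loop whose header PHI is recognized here. X is an FP induction:
// StartValue is 0.0f, Step is the loop-invariant 0.5f, and the FAdd feeding
// the backedge becomes InductionBinOp in the descriptor built below. As the
// transform above notes, consumers still require the FP math to be 'fast'.
float sumSamples(const float *Data, int N) {
  float Sum = 0.0f;
  float X = 0.0f;      // FP induction: X_{i+1} = X_i + 0.5f
  for (int I = 0; I != N; ++I) {
    Sum += Data[I] * X;
    X += 0.5f;         // loop-invariant addend -> a valid FP step
  }
  return Sum;
}
// Note the asymmetry checked below for FSub: X = X - 0.5f is still an
// induction, but X = 0.5f - X is not, because only the first operand of the
// subtraction may be the PHI.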
+ assert(Phi->getType()->isFloatingPointTy() && "Unexpected Phi type"); + + if (TheLoop->getHeader() != Phi->getParent()) + return false; + + // The loop may have multiple entrances or multiple exits; we can analyze + // this phi if it has a unique entry value and a unique backedge value. + if (Phi->getNumIncomingValues() != 2) + return false; + Value *BEValue = nullptr, *StartValue = nullptr; + if (TheLoop->contains(Phi->getIncomingBlock(0))) { + BEValue = Phi->getIncomingValue(0); + StartValue = Phi->getIncomingValue(1); + } else { + assert(TheLoop->contains(Phi->getIncomingBlock(1)) && + "Unexpected Phi node in the loop"); + BEValue = Phi->getIncomingValue(1); + StartValue = Phi->getIncomingValue(0); + } + + BinaryOperator *BOp = dyn_cast<BinaryOperator>(BEValue); + if (!BOp) + return false; + + Value *Addend = nullptr; + if (BOp->getOpcode() == Instruction::FAdd) { + if (BOp->getOperand(0) == Phi) + Addend = BOp->getOperand(1); + else if (BOp->getOperand(1) == Phi) + Addend = BOp->getOperand(0); + } else if (BOp->getOpcode() == Instruction::FSub) + if (BOp->getOperand(0) == Phi) + Addend = BOp->getOperand(1); + + if (!Addend) + return false; + + // The addend should be loop invariant + if (auto *I = dyn_cast<Instruction>(Addend)) + if (TheLoop->contains(I)) + return false; + + // FP Step has unknown SCEV + const SCEV *Step = SE->getUnknown(Addend); + D = InductionDescriptor(StartValue, IK_FpInduction, Step, BOp); + return true; +} + +bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, PredicatedScalarEvolution &PSE, InductionDescriptor &D, bool Assume) { Type *PhiTy = Phi->getType(); - // We only handle integer and pointer inductions variables. - if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy()) + + // Handle integer and pointer inductions variables. + // Now we handle also FP induction but not trying to make a + // recurrent expression from the PHI node in-place. + + if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy() && + !PhiTy->isFloatTy() && !PhiTy->isDoubleTy() && !PhiTy->isHalfTy()) return false; + if (PhiTy->isFloatingPointTy()) + return isFPInductionPHI(Phi, TheLoop, PSE.getSE(), D); + const SCEV *PhiScev = PSE.getSCEV(Phi); const auto *AR = dyn_cast<SCEVAddRecExpr>(PhiScev); @@ -752,10 +848,10 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, return false; } - return isInductionPHI(Phi, PSE.getSE(), D, AR); + return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR); } -bool InductionDescriptor::isInductionPHI(PHINode *Phi, +bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE, InductionDescriptor &D, const SCEV *Expr) { @@ -773,15 +869,20 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, return false; } - assert(AR->getLoop()->getHeader() == Phi->getParent() && - "PHI is an AddRec for a different loop?!"); + if (AR->getLoop() != TheLoop) { + // FIXME: We should treat this as a uniform. Unfortunately, we + // don't currently know how to handled uniform PHIs. + DEBUG(dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n"); + return false; + } + Value *StartValue = Phi->getIncomingValueForBlock(AR->getLoop()->getLoopPreheader()); const SCEV *Step = AR->getStepRecurrence(*SE); // Calculate the pointer stride and check if it is consecutive. // The stride may be a constant or a loop invariant integer value. 
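// Illustrative sketch (plain C++, hypothetical names, Stride assumed positive;
// not from this patch) of the two step shapes accepted just below: a
// compile-time constant, or a value that is merely loop invariant. Pointer
// inductions additionally insist on the constant form, as the asserts earlier
// in this file state.
void scale(float *P, float *End, long Stride, float K) {
  for (float *Q = P; Q < End; Q += 4)     // pointer induction, constant step
    *Q *= K;
  for (long I = 0; I < 256; I += Stride)  // integer induction, loop-invariant
    P[I] *= K;                            // (but not constant) step
}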
const SCEVConstant *ConstStep = dyn_cast<SCEVConstant>(Step); - if (!ConstStep && !SE->isLoopInvariant(Step, AR->getLoop())) + if (!ConstStep && !SE->isLoopInvariant(Step, TheLoop)) return false; if (PhiTy->isIntegerTy()) { @@ -824,7 +925,7 @@ SmallVector<Instruction *, 8> llvm::findDefsUsedOutsideOfLoop(Loop *L) { // be adapted into a pointer. for (auto &Inst : *Block) { auto Users = Inst.users(); - if (std::any_of(Users.begin(), Users.end(), [&](User *U) { + if (any_of(Users, [&](User *U) { auto *Use = cast<Instruction>(U); return !L->contains(Use->getParent()); })) @@ -851,6 +952,10 @@ void llvm::getLoopAnalysisUsage(AnalysisUsage &AU) { AU.addPreservedID(LoopSimplifyID); AU.addRequiredID(LCSSAID); AU.addPreservedID(LCSSAID); + // This is used in the LPPassManager to perform LCSSA verification on passes + // which preserve lcssa form + AU.addRequired<LCSSAVerificationPass>(); + AU.addPreserved<LCSSAVerificationPass>(); // Loop passes are designed to run inside of a loop pass manager which means // that any function analyses they require must be required by the first loop @@ -967,3 +1072,36 @@ bool llvm::isGuaranteedToExecute(const Instruction &Inst, // just a special case of this.) return true; } + +Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) { + // Only support loops with a unique exiting block, and a latch. + if (!L->getExitingBlock()) + return None; + + // Get the branch weights for the the loop's backedge. + BranchInst *LatchBR = + dyn_cast<BranchInst>(L->getLoopLatch()->getTerminator()); + if (!LatchBR || LatchBR->getNumSuccessors() != 2) + return None; + + assert((LatchBR->getSuccessor(0) == L->getHeader() || + LatchBR->getSuccessor(1) == L->getHeader()) && + "At least one edge out of the latch must go to the header"); + + // To estimate the number of times the loop body was executed, we want to + // know the number of times the backedge was taken, vs. the number of times + // we exited the loop. + uint64_t TrueVal, FalseVal; + if (!LatchBR->extractProfMetadata(TrueVal, FalseVal)) + return None; + + if (!TrueVal || !FalseVal) + return 0; + + // Divide the count of the backedge by the count of the edge exiting the loop, + // rounding to nearest. 
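// A minimal standalone sketch (hypothetical branch weights, not from this
// patch) of the round-to-nearest estimate computed just below: add half the
// exit count before dividing by it.
#include <cstdint>
#include <cstdio>

static uint64_t estimateTripCount(uint64_t BackedgeTaken, uint64_t LoopExited) {
  return (BackedgeTaken + LoopExited / 2) / LoopExited;
}

int main() {
  // 990 backedge executions against 100 loop exits: 990/100 = 9.9, which the
  // rounding division reports as 10.
  std::printf("%llu\n", (unsigned long long)estimateTripCount(990, 100));
}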
+ if (LatchBR->getSuccessor(0) == L->getHeader()) + return (TrueVal + (FalseVal / 2)) / FalseVal; + else + return (FalseVal + (TrueVal / 2)) / TrueVal; +} diff --git a/contrib/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/contrib/llvm/lib/Transforms/Utils/LoopVersioning.cpp index b3c6169..29756d9 100644 --- a/contrib/llvm/lib/Transforms/Utils/LoopVersioning.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LoopVersioning.cpp @@ -36,7 +36,7 @@ LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT), SE(SE) { assert(L->getExitBlock() && "No single exit block"); - assert(L->getLoopPreheader() && "No preheader"); + assert(L->isLoopSimplifyForm() && "Loop is not in loop-simplify form"); if (UseLAIChecks) { setAliasChecks(LAI.getRuntimePointerChecking()->getChecks()); setSCEVChecks(LAI.getPSE().getUnionPredicate()); @@ -278,8 +278,8 @@ public: bool Changed = false; for (Loop *L : Worklist) { const LoopAccessInfo &LAI = LAA->getInfo(L); - if (LAI.getNumRuntimePointerChecks() || - !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) { + if (L->isLoopSimplifyForm() && (LAI.getNumRuntimePointerChecks() || + !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { LoopVersioning LVer(LAI, L, LI, DT, SE); LVer.versionLoop(); LVer.annotateLoopWithNoAlias(); diff --git a/contrib/llvm/lib/Transforms/Utils/LowerInvoke.cpp b/contrib/llvm/lib/Transforms/Utils/LowerInvoke.cpp index 1b31c5a..ee84541 100644 --- a/contrib/llvm/lib/Transforms/Utils/LowerInvoke.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LowerInvoke.cpp @@ -14,6 +14,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils/LowerInvoke.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Instructions.h" @@ -28,36 +29,29 @@ using namespace llvm; STATISTIC(NumInvokes, "Number of invokes replaced"); namespace { - class LowerInvoke : public FunctionPass { + class LowerInvokeLegacyPass : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid - explicit LowerInvoke() : FunctionPass(ID) { - initializeLowerInvokePass(*PassRegistry::getPassRegistry()); + explicit LowerInvokeLegacyPass() : FunctionPass(ID) { + initializeLowerInvokeLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; }; } -char LowerInvoke::ID = 0; -INITIALIZE_PASS(LowerInvoke, "lowerinvoke", +char LowerInvokeLegacyPass::ID = 0; +INITIALIZE_PASS(LowerInvokeLegacyPass, "lowerinvoke", "Lower invoke and unwind, for unwindless code generators", false, false) -char &llvm::LowerInvokePassID = LowerInvoke::ID; - -// Public Interface To the LowerInvoke pass. -FunctionPass *llvm::createLowerInvokePass() { - return new LowerInvoke(); -} - -bool LowerInvoke::runOnFunction(Function &F) { +static bool runImpl(Function &F) { bool Changed = false; for (BasicBlock &BB : F) if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator())) { - SmallVector<Value*,16> CallArgs(II->op_begin(), II->op_end() - 3); + SmallVector<Value *, 16> CallArgs(II->op_begin(), II->op_end() - 3); // Insert a normal call instruction... 
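// A schematic sketch (IR shown in comments, block and value names are
// hypothetical, not from this patch) of the rewrite performed here:
//
//   before:  %r = invoke i32 @callee(i32 %x)
//                    to label %normal unwind label %lpad
//
//   after:   %r = call i32 @callee(i32 %x)
//            br label %normal
//
// The unwind edge to %lpad is simply dropped, which is why the pass describes
// itself as being for "unwindless code generators".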
- CallInst *NewCall = CallInst::Create(II->getCalledValue(), - CallArgs, "", II); + CallInst *NewCall = + CallInst::Create(II->getCalledValue(), CallArgs, "", II); NewCall->takeName(II); NewCall->setCallingConv(II->getCallingConv()); NewCall->setAttributes(II->getAttributes()); @@ -73,7 +67,28 @@ bool LowerInvoke::runOnFunction(Function &F) { // Remove the invoke instruction now. BB.getInstList().erase(II); - ++NumInvokes; Changed = true; + ++NumInvokes; + Changed = true; } return Changed; } + +bool LowerInvokeLegacyPass::runOnFunction(Function &F) { + return runImpl(F); +} + +namespace llvm { +char &LowerInvokePassID = LowerInvokeLegacyPass::ID; + +// Public Interface To the LowerInvoke pass. +FunctionPass *createLowerInvokePass() { return new LowerInvokeLegacyPass(); } + +PreservedAnalyses LowerInvokePass::run(Function &F, + FunctionAnalysisManager &AM) { + bool Changed = runImpl(F); + if (!Changed) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} +} diff --git a/contrib/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/contrib/llvm/lib/Transforms/Utils/LowerSwitch.cpp index 5c07469..75cd3bc 100644 --- a/contrib/llvm/lib/Transforms/Utils/LowerSwitch.cpp +++ b/contrib/llvm/lib/Transforms/Utils/LowerSwitch.cpp @@ -478,10 +478,10 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, // cases. assert(MaxPop > 0 && PopSucc); Default = PopSucc; - Cases.erase(std::remove_if( - Cases.begin(), Cases.end(), - [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }), - Cases.end()); + Cases.erase( + remove_if(Cases, + [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }), + Cases.end()); // If there are no cases left, just branch. if (Cases.empty()) { diff --git a/contrib/llvm/lib/Transforms/Utils/Mem2Reg.cpp b/contrib/llvm/lib/Transforms/Utils/Mem2Reg.cpp index 1419254..24b3b12 100644 --- a/contrib/llvm/lib/Transforms/Utils/Mem2Reg.cpp +++ b/contrib/llvm/lib/Transforms/Utils/Mem2Reg.cpp @@ -53,7 +53,7 @@ static bool promoteMemoryToRegister(Function &F, DominatorTree &DT, return Changed; } -PreservedAnalyses PromotePass::run(Function &F, AnalysisManager<Function> &AM) { +PreservedAnalyses PromotePass::run(Function &F, FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); if (!promoteMemoryToRegister(F, DT, AC)) diff --git a/contrib/llvm/lib/Transforms/Utils/MemorySSA.cpp b/contrib/llvm/lib/Transforms/Utils/MemorySSA.cpp index 8ba3cae..1ce4225 100644 --- a/contrib/llvm/lib/Transforms/Utils/MemorySSA.cpp +++ b/contrib/llvm/lib/Transforms/Utils/MemorySSA.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/Statistic.h" @@ -60,6 +61,11 @@ INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa", "Memory SSA Printer", false, false) +static cl::opt<unsigned> MaxCheckLimit( + "memssa-check-limit", cl::Hidden, cl::init(100), + cl::desc("The maximum number of stores/phis MemorySSA" + "will consider trying to walk past (default = 100)")); + static cl::opt<bool> VerifyMemorySSA("verify-memoryssa", cl::init(false), cl::Hidden, cl::desc("Verify MemorySSA in legacy printer pass.")); @@ -86,7 +92,963 @@ public: OS << "; " << *MA << "\n"; } }; +} + +namespace { +/// Our current alias analysis API differentiates heavily between calls and +/// non-calls, and 
functions called on one usually assert on the other. +/// This class encapsulates the distinction to simplify other code that wants +/// "Memory affecting instructions and related data" to use as a key. +/// For example, this class is used as a densemap key in the use optimizer. +class MemoryLocOrCall { +public: + MemoryLocOrCall() : IsCall(false) {} + MemoryLocOrCall(MemoryUseOrDef *MUD) + : MemoryLocOrCall(MUD->getMemoryInst()) {} + MemoryLocOrCall(const MemoryUseOrDef *MUD) + : MemoryLocOrCall(MUD->getMemoryInst()) {} + + MemoryLocOrCall(Instruction *Inst) { + if (ImmutableCallSite(Inst)) { + IsCall = true; + CS = ImmutableCallSite(Inst); + } else { + IsCall = false; + // There is no such thing as a memorylocation for a fence inst, and it is + // unique in that regard. + if (!isa<FenceInst>(Inst)) + Loc = MemoryLocation::get(Inst); + } + } + + explicit MemoryLocOrCall(const MemoryLocation &Loc) + : IsCall(false), Loc(Loc) {} + + bool IsCall; + ImmutableCallSite getCS() const { + assert(IsCall); + return CS; + } + MemoryLocation getLoc() const { + assert(!IsCall); + return Loc; + } + + bool operator==(const MemoryLocOrCall &Other) const { + if (IsCall != Other.IsCall) + return false; + + if (IsCall) + return CS.getCalledValue() == Other.CS.getCalledValue(); + return Loc == Other.Loc; + } + +private: + union { + ImmutableCallSite CS; + MemoryLocation Loc; + }; +}; +} + +namespace llvm { +template <> struct DenseMapInfo<MemoryLocOrCall> { + static inline MemoryLocOrCall getEmptyKey() { + return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getEmptyKey()); + } + static inline MemoryLocOrCall getTombstoneKey() { + return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getTombstoneKey()); + } + static unsigned getHashValue(const MemoryLocOrCall &MLOC) { + if (MLOC.IsCall) + return hash_combine(MLOC.IsCall, + DenseMapInfo<const Value *>::getHashValue( + MLOC.getCS().getCalledValue())); + return hash_combine( + MLOC.IsCall, DenseMapInfo<MemoryLocation>::getHashValue(MLOC.getLoc())); + } + static bool isEqual(const MemoryLocOrCall &LHS, const MemoryLocOrCall &RHS) { + return LHS == RHS; + } +}; + +enum class Reorderability { Always, IfNoAlias, Never }; + +/// This does one-way checks to see if Use could theoretically be hoisted above +/// MayClobber. This will not check the other way around. +/// +/// This assumes that, for the purposes of MemorySSA, Use comes directly after +/// MayClobber, with no potentially clobbering operations in between them. +/// (Where potentially clobbering ops are memory barriers, aliased stores, etc.) +static Reorderability getLoadReorderability(const LoadInst *Use, + const LoadInst *MayClobber) { + bool VolatileUse = Use->isVolatile(); + bool VolatileClobber = MayClobber->isVolatile(); + // Volatile operations may never be reordered with other volatile operations. + if (VolatileUse && VolatileClobber) + return Reorderability::Never; + + // The lang ref allows reordering of volatile and non-volatile operations. + // Whether an aliasing nonvolatile load and volatile load can be reordered, + // though, is ambiguous. Because it may not be best to exploit this ambiguity, + // we only allow volatile/non-volatile reordering if the volatile and + // non-volatile operations don't alias. + Reorderability Result = VolatileUse || VolatileClobber + ? Reorderability::IfNoAlias + : Reorderability::Always; + + // If a load is seq_cst, it cannot be moved above other loads. If its ordering + // is weaker, it can be moved above other loads. 
We just need to be sure that + // MayClobber isn't an acquire load, because loads can't be moved above + // acquire loads. + // + // Note that this explicitly *does* allow the free reordering of monotonic (or + // weaker) loads of the same address. + bool SeqCstUse = Use->getOrdering() == AtomicOrdering::SequentiallyConsistent; + bool MayClobberIsAcquire = isAtLeastOrStrongerThan(MayClobber->getOrdering(), + AtomicOrdering::Acquire); + if (SeqCstUse || MayClobberIsAcquire) + return Reorderability::Never; + return Result; +} + +static bool instructionClobbersQuery(MemoryDef *MD, + const MemoryLocation &UseLoc, + const Instruction *UseInst, + AliasAnalysis &AA) { + Instruction *DefInst = MD->getMemoryInst(); + assert(DefInst && "Defining instruction not actually an instruction"); + + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(DefInst)) { + // These intrinsics will show up as affecting memory, but they are just + // markers. + switch (II->getIntrinsicID()) { + case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: + case Intrinsic::invariant_start: + case Intrinsic::invariant_end: + case Intrinsic::assume: + return false; + default: + break; + } + } + + ImmutableCallSite UseCS(UseInst); + if (UseCS) { + ModRefInfo I = AA.getModRefInfo(DefInst, UseCS); + return I != MRI_NoModRef; + } + + if (auto *DefLoad = dyn_cast<LoadInst>(DefInst)) { + if (auto *UseLoad = dyn_cast<LoadInst>(UseInst)) { + switch (getLoadReorderability(UseLoad, DefLoad)) { + case Reorderability::Always: + return false; + case Reorderability::Never: + return true; + case Reorderability::IfNoAlias: + return !AA.isNoAlias(UseLoc, MemoryLocation::get(DefLoad)); + } + } + } + + return AA.getModRefInfo(DefInst, UseLoc) & MRI_Mod; +} + +static bool instructionClobbersQuery(MemoryDef *MD, const MemoryUseOrDef *MU, + const MemoryLocOrCall &UseMLOC, + AliasAnalysis &AA) { + // FIXME: This is a temporary hack to allow a single instructionClobbersQuery + // to exist while MemoryLocOrCall is pushed through places. + if (UseMLOC.IsCall) + return instructionClobbersQuery(MD, MemoryLocation(), MU->getMemoryInst(), + AA); + return instructionClobbersQuery(MD, UseMLOC.getLoc(), MU->getMemoryInst(), + AA); +} + +// Return true when MD may alias MU, return false otherwise. +bool defClobbersUseOrDef(MemoryDef *MD, const MemoryUseOrDef *MU, + AliasAnalysis &AA) { + return instructionClobbersQuery(MD, MU, MemoryLocOrCall(MU), AA); +} +} + +namespace { +struct UpwardsMemoryQuery { + // True if our original query started off as a call + bool IsCall; + // The pointer location we started the query with. This will be empty if + // IsCall is true. + MemoryLocation StartingLoc; + // This is the instruction we were querying about. 
+ const Instruction *Inst; + // The MemoryAccess we actually got called with, used to test local domination + const MemoryAccess *OriginalAccess; + + UpwardsMemoryQuery() + : IsCall(false), Inst(nullptr), OriginalAccess(nullptr) {} + + UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access) + : IsCall(ImmutableCallSite(Inst)), Inst(Inst), OriginalAccess(Access) { + if (!IsCall) + StartingLoc = MemoryLocation::get(Inst); + } +}; + +static bool lifetimeEndsAt(MemoryDef *MD, const MemoryLocation &Loc, + AliasAnalysis &AA) { + Instruction *Inst = MD->getMemoryInst(); + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: + return AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), Loc); + default: + return false; + } + } + return false; +} + +static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysis &AA, + const Instruction *I) { + // If the memory can't be changed, then loads of the memory can't be + // clobbered. + // + // FIXME: We should handle invariant groups, as well. It's a bit harder, + // because we need to pay close attention to invariant group barriers. + return isa<LoadInst>(I) && (I->getMetadata(LLVMContext::MD_invariant_load) || + AA.pointsToConstantMemory(I)); +} + +/// Cache for our caching MemorySSA walker. +class WalkerCache { + DenseMap<ConstMemoryAccessPair, MemoryAccess *> Accesses; + DenseMap<const MemoryAccess *, MemoryAccess *> Calls; + +public: + MemoryAccess *lookup(const MemoryAccess *MA, const MemoryLocation &Loc, + bool IsCall) const { + ++NumClobberCacheLookups; + MemoryAccess *R = IsCall ? Calls.lookup(MA) : Accesses.lookup({MA, Loc}); + if (R) + ++NumClobberCacheHits; + return R; + } + + bool insert(const MemoryAccess *MA, MemoryAccess *To, + const MemoryLocation &Loc, bool IsCall) { + // This is fine for Phis, since there are times where we can't optimize + // them. Making a def its own clobber is never correct, though. + assert((MA != To || isa<MemoryPhi>(MA)) && + "Something can't clobber itself!"); + + ++NumClobberCacheInserts; + bool Inserted; + if (IsCall) + Inserted = Calls.insert({MA, To}).second; + else + Inserted = Accesses.insert({{MA, Loc}, To}).second; + + return Inserted; + } + + bool remove(const MemoryAccess *MA, const MemoryLocation &Loc, bool IsCall) { + return IsCall ? Calls.erase(MA) : Accesses.erase({MA, Loc}); + } + + void clear() { + Accesses.clear(); + Calls.clear(); + } + + bool contains(const MemoryAccess *MA) const { + for (auto &P : Accesses) + if (P.first.first == MA || P.second == MA) + return true; + for (auto &P : Calls) + if (P.first == MA || P.second == MA) + return true; + return false; + } +}; + +/// Walks the defining uses of MemoryDefs. Stops after we hit something that has +/// no defining use (e.g. a MemoryPhi or liveOnEntry). Note that, when comparing +/// against a null def_chain_iterator, this will compare equal only after +/// walking said Phi/liveOnEntry. +struct def_chain_iterator + : public iterator_facade_base<def_chain_iterator, std::forward_iterator_tag, + MemoryAccess *> { + def_chain_iterator() : MA(nullptr) {} + def_chain_iterator(MemoryAccess *MA) : MA(MA) {} + + MemoryAccess *operator*() const { return MA; } + + def_chain_iterator &operator++() { + // N.B. liveOnEntry has a null defining access. 
+ if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) + MA = MUD->getDefiningAccess(); + else + MA = nullptr; + return *this; + } + + bool operator==(const def_chain_iterator &O) const { return MA == O.MA; } + +private: + MemoryAccess *MA; +}; + +static iterator_range<def_chain_iterator> +def_chain(MemoryAccess *MA, MemoryAccess *UpTo = nullptr) { +#ifdef EXPENSIVE_CHECKS + assert((!UpTo || find(def_chain(MA), UpTo) != def_chain_iterator()) && + "UpTo isn't in the def chain!"); +#endif + return make_range(def_chain_iterator(MA), def_chain_iterator(UpTo)); +} + +/// Verifies that `Start` is clobbered by `ClobberAt`, and that nothing +/// inbetween `Start` and `ClobberAt` can clobbers `Start`. +/// +/// This is meant to be as simple and self-contained as possible. Because it +/// uses no cache, etc., it can be relatively expensive. +/// +/// \param Start The MemoryAccess that we want to walk from. +/// \param ClobberAt A clobber for Start. +/// \param StartLoc The MemoryLocation for Start. +/// \param MSSA The MemorySSA isntance that Start and ClobberAt belong to. +/// \param Query The UpwardsMemoryQuery we used for our search. +/// \param AA The AliasAnalysis we used for our search. +static void LLVM_ATTRIBUTE_UNUSED +checkClobberSanity(MemoryAccess *Start, MemoryAccess *ClobberAt, + const MemoryLocation &StartLoc, const MemorySSA &MSSA, + const UpwardsMemoryQuery &Query, AliasAnalysis &AA) { + assert(MSSA.dominates(ClobberAt, Start) && "Clobber doesn't dominate start?"); + + if (MSSA.isLiveOnEntryDef(Start)) { + assert(MSSA.isLiveOnEntryDef(ClobberAt) && + "liveOnEntry must clobber itself"); + return; + } + + bool FoundClobber = false; + DenseSet<MemoryAccessPair> VisitedPhis; + SmallVector<MemoryAccessPair, 8> Worklist; + Worklist.emplace_back(Start, StartLoc); + // Walk all paths from Start to ClobberAt, while looking for clobbers. If one + // is found, complain. + while (!Worklist.empty()) { + MemoryAccessPair MAP = Worklist.pop_back_val(); + // All we care about is that nothing from Start to ClobberAt clobbers Start. + // We learn nothing from revisiting nodes. + if (!VisitedPhis.insert(MAP).second) + continue; + + for (MemoryAccess *MA : def_chain(MAP.first)) { + if (MA == ClobberAt) { + if (auto *MD = dyn_cast<MemoryDef>(MA)) { + // instructionClobbersQuery isn't essentially free, so don't use `|=`, + // since it won't let us short-circuit. + // + // Also, note that this can't be hoisted out of the `Worklist` loop, + // since MD may only act as a clobber for 1 of N MemoryLocations. + FoundClobber = + FoundClobber || MSSA.isLiveOnEntryDef(MD) || + instructionClobbersQuery(MD, MAP.second, Query.Inst, AA); + } + break; + } + + // We should never hit liveOnEntry, unless it's the clobber. + assert(!MSSA.isLiveOnEntryDef(MA) && "Hit liveOnEntry before clobber?"); + + if (auto *MD = dyn_cast<MemoryDef>(MA)) { + (void)MD; + assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA) && + "Found clobber before reaching ClobberAt!"); + continue; + } + + assert(isa<MemoryPhi>(MA)); + Worklist.append(upward_defs_begin({MA, MAP.second}), upward_defs_end()); + } + } + + // If ClobberAt is a MemoryPhi, we can assume something above it acted as a + // clobber. Otherwise, `ClobberAt` should've acted as a clobber at some point. + assert((isa<MemoryPhi>(ClobberAt) || FoundClobber) && + "ClobberAt never acted as a clobber"); +} + +/// Our algorithm for walking (and trying to optimize) clobbers, all wrapped up +/// in one class. 
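// A minimal standalone analogue (hypothetical Node type, not the MemorySSA
// API) of the def_chain() walk defined above: follow the single defining link
// until an access with no defining link -- a MemoryPhi or liveOnEntry in the
// real code -- has been visited.
#include <cstdio>

struct Node {
  const char *Name;
  Node *DefiningAccess; // null once we pass the top of the chain
};

static void walkDefChain(Node *N) {
  for (; N; N = N->DefiningAccess)
    std::printf("%s\n", N->Name);
}

int main() {
  Node LiveOnEntry{"liveOnEntry", nullptr};
  Node Def1{"1 = MemoryDef(liveOnEntry)", &LiveOnEntry};
  Node Def2{"2 = MemoryDef(1)", &Def1};
  walkDefChain(&Def2); // prints 2, then 1, then liveOnEntry, then stops
}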
+class ClobberWalker { + /// Save a few bytes by using unsigned instead of size_t. + using ListIndex = unsigned; + + /// Represents a span of contiguous MemoryDefs, potentially ending in a + /// MemoryPhi. + struct DefPath { + MemoryLocation Loc; + // Note that, because we always walk in reverse, Last will always dominate + // First. Also note that First and Last are inclusive. + MemoryAccess *First; + MemoryAccess *Last; + Optional<ListIndex> Previous; + + DefPath(const MemoryLocation &Loc, MemoryAccess *First, MemoryAccess *Last, + Optional<ListIndex> Previous) + : Loc(Loc), First(First), Last(Last), Previous(Previous) {} + + DefPath(const MemoryLocation &Loc, MemoryAccess *Init, + Optional<ListIndex> Previous) + : DefPath(Loc, Init, Init, Previous) {} + }; + + const MemorySSA &MSSA; + AliasAnalysis &AA; + DominatorTree &DT; + WalkerCache &WC; + UpwardsMemoryQuery *Query; + bool UseCache; + + // Phi optimization bookkeeping + SmallVector<DefPath, 32> Paths; + DenseSet<ConstMemoryAccessPair> VisitedPhis; + DenseMap<const BasicBlock *, MemoryAccess *> WalkTargetCache; + + void setUseCache(bool Use) { UseCache = Use; } + bool shouldIgnoreCache() const { + // UseCache will only be false when we're debugging, or when expensive + // checks are enabled. In either case, we don't care deeply about speed. + return LLVM_UNLIKELY(!UseCache); + } + + void addCacheEntry(const MemoryAccess *What, MemoryAccess *To, + const MemoryLocation &Loc) const { +// EXPENSIVE_CHECKS because most of these queries are redundant. +#ifdef EXPENSIVE_CHECKS + assert(MSSA.dominates(To, What)); +#endif + if (shouldIgnoreCache()) + return; + WC.insert(What, To, Loc, Query->IsCall); + } + + MemoryAccess *lookupCache(const MemoryAccess *MA, const MemoryLocation &Loc) { + return shouldIgnoreCache() ? nullptr : WC.lookup(MA, Loc, Query->IsCall); + } + + void cacheDefPath(const DefPath &DN, MemoryAccess *Target) const { + if (shouldIgnoreCache()) + return; + + for (MemoryAccess *MA : def_chain(DN.First, DN.Last)) + addCacheEntry(MA, Target, DN.Loc); + + // DefPaths only express the path we walked. So, DN.Last could either be a + // thing we want to cache, or not. + if (DN.Last != Target) + addCacheEntry(DN.Last, Target, DN.Loc); + } + + /// Find the nearest def or phi that `From` can legally be optimized to. + /// + /// FIXME: Deduplicate this with MSSA::findDominatingDef. Ideally, MSSA should + /// keep track of this information for us, and allow us O(1) lookups of this + /// info. + MemoryAccess *getWalkTarget(const MemoryPhi *From) { + assert(From->getNumOperands() && "Phi with no operands?"); + + BasicBlock *BB = From->getBlock(); + auto At = WalkTargetCache.find(BB); + if (At != WalkTargetCache.end()) + return At->second; + + SmallVector<const BasicBlock *, 8> ToCache; + ToCache.push_back(BB); + + MemoryAccess *Result = MSSA.getLiveOnEntryDef(); + DomTreeNode *Node = DT.getNode(BB); + while ((Node = Node->getIDom())) { + auto At = WalkTargetCache.find(BB); + if (At != WalkTargetCache.end()) { + Result = At->second; + break; + } + + auto *Accesses = MSSA.getBlockAccesses(Node->getBlock()); + if (Accesses) { + auto Iter = find_if(reverse(*Accesses), [](const MemoryAccess &MA) { + return !isa<MemoryUse>(MA); + }); + if (Iter != Accesses->rend()) { + Result = const_cast<MemoryAccess *>(&*Iter); + break; + } + } + + ToCache.push_back(Node->getBlock()); + } + + for (const BasicBlock *BB : ToCache) + WalkTargetCache.insert({BB, Result}); + return Result; + } + + /// Result of calling walkToPhiOrClobber. 
+ struct UpwardsWalkResult { + /// The "Result" of the walk. Either a clobber, the last thing we walked, or + /// both. + MemoryAccess *Result; + bool IsKnownClobber; + bool FromCache; + }; + + /// Walk to the next Phi or Clobber in the def chain starting at Desc.Last. + /// This will update Desc.Last as it walks. It will (optionally) also stop at + /// StopAt. + /// + /// This does not test for whether StopAt is a clobber + UpwardsWalkResult walkToPhiOrClobber(DefPath &Desc, + MemoryAccess *StopAt = nullptr) { + assert(!isa<MemoryUse>(Desc.Last) && "Uses don't exist in my world"); + + for (MemoryAccess *Current : def_chain(Desc.Last)) { + Desc.Last = Current; + if (Current == StopAt) + return {Current, false, false}; + + if (auto *MD = dyn_cast<MemoryDef>(Current)) + if (MSSA.isLiveOnEntryDef(MD) || + instructionClobbersQuery(MD, Desc.Loc, Query->Inst, AA)) + return {MD, true, false}; + + // Cache checks must be done last, because if Current is a clobber, the + // cache will contain the clobber for Current. + if (MemoryAccess *MA = lookupCache(Current, Desc.Loc)) + return {MA, true, true}; + } + + assert(isa<MemoryPhi>(Desc.Last) && + "Ended at a non-clobber that's not a phi?"); + return {Desc.Last, false, false}; + } + + void addSearches(MemoryPhi *Phi, SmallVectorImpl<ListIndex> &PausedSearches, + ListIndex PriorNode) { + auto UpwardDefs = make_range(upward_defs_begin({Phi, Paths[PriorNode].Loc}), + upward_defs_end()); + for (const MemoryAccessPair &P : UpwardDefs) { + PausedSearches.push_back(Paths.size()); + Paths.emplace_back(P.second, P.first, PriorNode); + } + } + + /// Represents a search that terminated after finding a clobber. This clobber + /// may or may not be present in the path of defs from LastNode..SearchStart, + /// since it may have been retrieved from cache. + struct TerminatedPath { + MemoryAccess *Clobber; + ListIndex LastNode; + }; + + /// Get an access that keeps us from optimizing to the given phi. + /// + /// PausedSearches is an array of indices into the Paths array. Its incoming + /// value is the indices of searches that stopped at the last phi optimization + /// target. It's left in an unspecified state. + /// + /// If this returns None, NewPaused is a vector of searches that terminated + /// at StopWhere. Otherwise, NewPaused is left in an unspecified state. + Optional<TerminatedPath> + getBlockingAccess(MemoryAccess *StopWhere, + SmallVectorImpl<ListIndex> &PausedSearches, + SmallVectorImpl<ListIndex> &NewPaused, + SmallVectorImpl<TerminatedPath> &Terminated) { + assert(!PausedSearches.empty() && "No searches to continue?"); + + // BFS vs DFS really doesn't make a difference here, so just do a DFS with + // PausedSearches as our stack. + while (!PausedSearches.empty()) { + ListIndex PathIndex = PausedSearches.pop_back_val(); + DefPath &Node = Paths[PathIndex]; + + // If we've already visited this path with this MemoryLocation, we don't + // need to do so again. + // + // NOTE: That we just drop these paths on the ground makes caching + // behavior sporadic. e.g. given a diamond: + // A + // B C + // D + // + // ...If we walk D, B, A, C, we'll only cache the result of phi + // optimization for A, B, and D; C will be skipped because it dies here. + // This arguably isn't the worst thing ever, since: + // - We generally query things in a top-down order, so if we got below D + // without needing cache entries for {C, MemLoc}, then chances are + // that those cache entries would end up ultimately unused. 
+ // - We still cache things for A, so C only needs to walk up a bit. + // If this behavior becomes problematic, we can fix without a ton of extra + // work. + if (!VisitedPhis.insert({Node.Last, Node.Loc}).second) + continue; + + UpwardsWalkResult Res = walkToPhiOrClobber(Node, /*StopAt=*/StopWhere); + if (Res.IsKnownClobber) { + assert(Res.Result != StopWhere || Res.FromCache); + // If this wasn't a cache hit, we hit a clobber when walking. That's a + // failure. + TerminatedPath Term{Res.Result, PathIndex}; + if (!Res.FromCache || !MSSA.dominates(Res.Result, StopWhere)) + return Term; + + // Otherwise, it's a valid thing to potentially optimize to. + Terminated.push_back(Term); + continue; + } + + if (Res.Result == StopWhere) { + // We've hit our target. Save this path off for if we want to continue + // walking. + NewPaused.push_back(PathIndex); + continue; + } + + assert(!MSSA.isLiveOnEntryDef(Res.Result) && "liveOnEntry is a clobber"); + addSearches(cast<MemoryPhi>(Res.Result), PausedSearches, PathIndex); + } + + return None; + } + + template <typename T, typename Walker> + struct generic_def_path_iterator + : public iterator_facade_base<generic_def_path_iterator<T, Walker>, + std::forward_iterator_tag, T *> { + generic_def_path_iterator() : W(nullptr), N(None) {} + generic_def_path_iterator(Walker *W, ListIndex N) : W(W), N(N) {} + + T &operator*() const { return curNode(); } + + generic_def_path_iterator &operator++() { + N = curNode().Previous; + return *this; + } + + bool operator==(const generic_def_path_iterator &O) const { + if (N.hasValue() != O.N.hasValue()) + return false; + return !N.hasValue() || *N == *O.N; + } + + private: + T &curNode() const { return W->Paths[*N]; } + + Walker *W; + Optional<ListIndex> N; + }; + + using def_path_iterator = generic_def_path_iterator<DefPath, ClobberWalker>; + using const_def_path_iterator = + generic_def_path_iterator<const DefPath, const ClobberWalker>; + + iterator_range<def_path_iterator> def_path(ListIndex From) { + return make_range(def_path_iterator(this, From), def_path_iterator()); + } + + iterator_range<const_def_path_iterator> const_def_path(ListIndex From) const { + return make_range(const_def_path_iterator(this, From), + const_def_path_iterator()); + } + + struct OptznResult { + /// The path that contains our result. + TerminatedPath PrimaryClobber; + /// The paths that we can legally cache back from, but that aren't + /// necessarily the result of the Phi optimization. + SmallVector<TerminatedPath, 4> OtherClobbers; + }; + + ListIndex defPathIndex(const DefPath &N) const { + // The assert looks nicer if we don't need to do &N + const DefPath *NP = &N; + assert(!Paths.empty() && NP >= &Paths.front() && NP <= &Paths.back() && + "Out of bounds DefPath!"); + return NP - &Paths.front(); + } + + /// Try to optimize a phi as best as we can. Returns a SmallVector of Paths + /// that act as legal clobbers. Note that this won't return *all* clobbers. + /// + /// Phi optimization algorithm tl;dr: + /// - Find the earliest def/phi, A, we can optimize to + /// - Find if all paths from the starting memory access ultimately reach A + /// - If not, optimization isn't possible. + /// - Otherwise, walk from A to another clobber or phi, A'. + /// - If A' is a def, we're done. + /// - If A' is a phi, try to optimize it. + /// + /// A path is a series of {MemoryAccess, MemoryLocation} pairs. A path + /// terminates when a MemoryAccess that clobbers said MemoryLocation is found. 
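// A worked example (hypothetical, heavily simplified accesses; not from this
// patch) of the tl;dr above:
//
//   entry:  1 = MemoryDef(liveOnEntry)      ; store to %a
//     |-- left:  (no memory accesses)
//     |-- right: 2 = MemoryDef(1)           ; store to %b, no alias with %a
//   merge:  3 = MemoryPhi({left, 1}, {right, 2})
//           MemoryUse(3)                    ; load from %a
//
// Optimizing the use's walk past phi 3: getWalkTarget(3) yields 1, the nearest
// def on the dominator chain. The left path reaches 1 directly; the right path
// passes 2, which does not clobber the load of %a, and also reaches 1. Every
// path reached the target without a clobber, so the walk settles on 1 and the
// phi is skipped. Had 2 aliased %a, the right-hand path would terminate at 2
// and the phi itself would remain the answer.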
+ OptznResult tryOptimizePhi(MemoryPhi *Phi, MemoryAccess *Start, + const MemoryLocation &Loc) { + assert(Paths.empty() && VisitedPhis.empty() && + "Reset the optimization state."); + + Paths.emplace_back(Loc, Start, Phi, None); + // Stores how many "valid" optimization nodes we had prior to calling + // addSearches/getBlockingAccess. Necessary for caching if we had a blocker. + auto PriorPathsSize = Paths.size(); + + SmallVector<ListIndex, 16> PausedSearches; + SmallVector<ListIndex, 8> NewPaused; + SmallVector<TerminatedPath, 4> TerminatedPaths; + + addSearches(Phi, PausedSearches, 0); + + // Moves the TerminatedPath with the "most dominated" Clobber to the end of + // Paths. + auto MoveDominatedPathToEnd = [&](SmallVectorImpl<TerminatedPath> &Paths) { + assert(!Paths.empty() && "Need a path to move"); + auto Dom = Paths.begin(); + for (auto I = std::next(Dom), E = Paths.end(); I != E; ++I) + if (!MSSA.dominates(I->Clobber, Dom->Clobber)) + Dom = I; + auto Last = Paths.end() - 1; + if (Last != Dom) + std::iter_swap(Last, Dom); + }; + + MemoryPhi *Current = Phi; + while (1) { + assert(!MSSA.isLiveOnEntryDef(Current) && + "liveOnEntry wasn't treated as a clobber?"); + + MemoryAccess *Target = getWalkTarget(Current); + // If a TerminatedPath doesn't dominate Target, then it wasn't a legal + // optimization for the prior phi. + assert(all_of(TerminatedPaths, [&](const TerminatedPath &P) { + return MSSA.dominates(P.Clobber, Target); + })); + + // FIXME: This is broken, because the Blocker may be reported to be + // liveOnEntry, and we'll happily wait for that to disappear (read: never) + // For the moment, this is fine, since we do nothing with blocker info. + if (Optional<TerminatedPath> Blocker = getBlockingAccess( + Target, PausedSearches, NewPaused, TerminatedPaths)) { + // Cache our work on the blocking node, since we know that's correct. + cacheDefPath(Paths[Blocker->LastNode], Blocker->Clobber); + + // Find the node we started at. We can't search based on N->Last, since + // we may have gone around a loop with a different MemoryLocation. + auto Iter = find_if(def_path(Blocker->LastNode), [&](const DefPath &N) { + return defPathIndex(N) < PriorPathsSize; + }); + assert(Iter != def_path_iterator()); + + DefPath &CurNode = *Iter; + assert(CurNode.Last == Current); + + // Two things: + // A. We can't reliably cache all of NewPaused back. Consider a case + // where we have two paths in NewPaused; one of which can't optimize + // above this phi, whereas the other can. If we cache the second path + // back, we'll end up with suboptimal cache entries. We can handle + // cases like this a bit better when we either try to find all + // clobbers that block phi optimization, or when our cache starts + // supporting unfinished searches. + // B. We can't reliably cache TerminatedPaths back here without doing + // extra checks; consider a case like: + // T + // / \ + // D C + // \ / + // S + // Where T is our target, C is a node with a clobber on it, D is a + // diamond (with a clobber *only* on the left or right node, N), and + // S is our start. Say we walk to D, through the node opposite N + // (read: ignoring the clobber), and see a cache entry in the top + // node of D. That cache entry gets put into TerminatedPaths. We then + // walk up to C (N is later in our worklist), find the clobber, and + // quit. If we append TerminatedPaths to OtherClobbers, we'll cache + // the bottom part of D to the cached clobber, ignoring the clobber + // in N. 
Again, this problem goes away if we start tracking all + // blockers for a given phi optimization. + TerminatedPath Result{CurNode.Last, defPathIndex(CurNode)}; + return {Result, {}}; + } + + // If there's nothing left to search, then all paths led to valid clobbers + // that we got from our cache; pick the nearest to the start, and allow + // the rest to be cached back. + if (NewPaused.empty()) { + MoveDominatedPathToEnd(TerminatedPaths); + TerminatedPath Result = TerminatedPaths.pop_back_val(); + return {Result, std::move(TerminatedPaths)}; + } + + MemoryAccess *DefChainEnd = nullptr; + SmallVector<TerminatedPath, 4> Clobbers; + for (ListIndex Paused : NewPaused) { + UpwardsWalkResult WR = walkToPhiOrClobber(Paths[Paused]); + if (WR.IsKnownClobber) + Clobbers.push_back({WR.Result, Paused}); + else + // Micro-opt: If we hit the end of the chain, save it. + DefChainEnd = WR.Result; + } + + if (!TerminatedPaths.empty()) { + // If we couldn't find the dominating phi/liveOnEntry in the above loop, + // do it now. + if (!DefChainEnd) + for (MemoryAccess *MA : def_chain(Target)) + DefChainEnd = MA; + + // If any of the terminated paths don't dominate the phi we'll try to + // optimize, we need to figure out what they are and quit. + const BasicBlock *ChainBB = DefChainEnd->getBlock(); + for (const TerminatedPath &TP : TerminatedPaths) { + // Because we know that DefChainEnd is as "high" as we can go, we + // don't need local dominance checks; BB dominance is sufficient. + if (DT.dominates(ChainBB, TP.Clobber->getBlock())) + Clobbers.push_back(TP); + } + } + + // If we have clobbers in the def chain, find the one closest to Current + // and quit. + if (!Clobbers.empty()) { + MoveDominatedPathToEnd(Clobbers); + TerminatedPath Result = Clobbers.pop_back_val(); + return {Result, std::move(Clobbers)}; + } + + assert(all_of(NewPaused, + [&](ListIndex I) { return Paths[I].Last == DefChainEnd; })); + + // Because liveOnEntry is a clobber, this must be a phi. + auto *DefChainPhi = cast<MemoryPhi>(DefChainEnd); + + PriorPathsSize = Paths.size(); + PausedSearches.clear(); + for (ListIndex I : NewPaused) + addSearches(DefChainPhi, PausedSearches, I); + NewPaused.clear(); + + Current = DefChainPhi; + } + } + + /// Caches everything in an OptznResult. + void cacheOptResult(const OptznResult &R) { + if (R.OtherClobbers.empty()) { + // If we're not going to be caching OtherClobbers, don't bother with + // marking visited/etc. + for (const DefPath &N : const_def_path(R.PrimaryClobber.LastNode)) + cacheDefPath(N, R.PrimaryClobber.Clobber); + return; + } + + // PrimaryClobber is our answer. If we can cache anything back, we need to + // stop caching when we visit PrimaryClobber. 
+ SmallBitVector Visited(Paths.size()); + for (const DefPath &N : const_def_path(R.PrimaryClobber.LastNode)) { + Visited[defPathIndex(N)] = true; + cacheDefPath(N, R.PrimaryClobber.Clobber); + } + + for (const TerminatedPath &P : R.OtherClobbers) { + for (const DefPath &N : const_def_path(P.LastNode)) { + ListIndex NIndex = defPathIndex(N); + if (Visited[NIndex]) + break; + Visited[NIndex] = true; + cacheDefPath(N, P.Clobber); + } + } + } + + void verifyOptResult(const OptznResult &R) const { + assert(all_of(R.OtherClobbers, [&](const TerminatedPath &P) { + return MSSA.dominates(P.Clobber, R.PrimaryClobber.Clobber); + })); + } + + void resetPhiOptznState() { + Paths.clear(); + VisitedPhis.clear(); + } + +public: + ClobberWalker(const MemorySSA &MSSA, AliasAnalysis &AA, DominatorTree &DT, + WalkerCache &WC) + : MSSA(MSSA), AA(AA), DT(DT), WC(WC), UseCache(true) {} + + void reset() { WalkTargetCache.clear(); } + + /// Finds the nearest clobber for the given query, optimizing phis if + /// possible. + MemoryAccess *findClobber(MemoryAccess *Start, UpwardsMemoryQuery &Q, + bool UseWalkerCache = true) { + setUseCache(UseWalkerCache); + Query = &Q; + + MemoryAccess *Current = Start; + // This walker pretends uses don't exist. If we're handed one, silently grab + // its def. (This has the nice side-effect of ensuring we never cache uses) + if (auto *MU = dyn_cast<MemoryUse>(Start)) + Current = MU->getDefiningAccess(); + + DefPath FirstDesc(Q.StartingLoc, Current, Current, None); + // Fast path for the overly-common case (no crazy phi optimization + // necessary) + UpwardsWalkResult WalkResult = walkToPhiOrClobber(FirstDesc); + MemoryAccess *Result; + if (WalkResult.IsKnownClobber) { + cacheDefPath(FirstDesc, WalkResult.Result); + Result = WalkResult.Result; + } else { + OptznResult OptRes = tryOptimizePhi(cast<MemoryPhi>(FirstDesc.Last), + Current, Q.StartingLoc); + verifyOptResult(OptRes); + cacheOptResult(OptRes); + resetPhiOptznState(); + Result = OptRes.PrimaryClobber.Clobber; + } + +#ifdef EXPENSIVE_CHECKS + checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA); +#endif + return Result; + } + + void verify(const MemorySSA *MSSA) { assert(MSSA == &this->MSSA); } +}; + +struct RenamePassData { + DomTreeNode *DTN; + DomTreeNode::const_iterator ChildIt; + MemoryAccess *IncomingVal; + + RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It, + MemoryAccess *M) + : DTN(D), ChildIt(It), IncomingVal(M) {} + void swap(RenamePassData &RHS) { + std::swap(DTN, RHS.DTN); + std::swap(ChildIt, RHS.ChildIt); + std::swap(IncomingVal, RHS.IncomingVal); + } +}; +} // anonymous namespace +namespace llvm { /// \brief A MemorySSAWalker that does AA walks and caching of lookups to /// disambiguate accesses. 
/// @@ -121,59 +1083,39 @@ public: /// ret i32 %r /// } class MemorySSA::CachingWalker final : public MemorySSAWalker { + WalkerCache Cache; + ClobberWalker Walker; + bool AutoResetWalker; + + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, UpwardsMemoryQuery &); + void verifyRemoved(MemoryAccess *); + public: CachingWalker(MemorySSA *, AliasAnalysis *, DominatorTree *); ~CachingWalker() override; - MemoryAccess *getClobberingMemoryAccess(const Instruction *) override; + using MemorySSAWalker::getClobberingMemoryAccess; + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *) override; MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, - MemoryLocation &) override; + const MemoryLocation &) override; void invalidateInfo(MemoryAccess *) override; -protected: - struct UpwardsMemoryQuery; - MemoryAccess *doCacheLookup(const MemoryAccess *, const UpwardsMemoryQuery &, - const MemoryLocation &); - - void doCacheInsert(const MemoryAccess *, MemoryAccess *, - const UpwardsMemoryQuery &, const MemoryLocation &); + /// Whether we call resetClobberWalker() after each time we *actually* walk to + /// answer a clobber query. + void setAutoResetWalker(bool AutoReset) { AutoResetWalker = AutoReset; } - void doCacheRemove(const MemoryAccess *, const UpwardsMemoryQuery &, - const MemoryLocation &); + /// Drop the walker's persistent data structures. At the moment, this means + /// "drop the walker's cache of BasicBlocks -> + /// earliest-MemoryAccess-we-can-optimize-to". This is necessary if we're + /// going to have DT updates, if we remove MemoryAccesses, etc. + void resetClobberWalker() { Walker.reset(); } -private: - MemoryAccessPair UpwardsDFSWalk(MemoryAccess *, const MemoryLocation &, - UpwardsMemoryQuery &, bool); - MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, UpwardsMemoryQuery &); - bool instructionClobbersQuery(const MemoryDef *, UpwardsMemoryQuery &, - const MemoryLocation &Loc) const; - void verifyRemoved(MemoryAccess *); - SmallDenseMap<ConstMemoryAccessPair, MemoryAccess *> - CachedUpwardsClobberingAccess; - DenseMap<const MemoryAccess *, MemoryAccess *> CachedUpwardsClobberingCall; - AliasAnalysis *AA; - DominatorTree *DT; -}; -} - -namespace { -struct RenamePassData { - DomTreeNode *DTN; - DomTreeNode::const_iterator ChildIt; - MemoryAccess *IncomingVal; - - RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It, - MemoryAccess *M) - : DTN(D), ChildIt(It), IncomingVal(M) {} - void swap(RenamePassData &RHS) { - std::swap(DTN, RHS.DTN); - std::swap(ChildIt, RHS.ChildIt); - std::swap(IncomingVal, RHS.IncomingVal); + void verify(const MemorySSA *MSSA) override { + MemorySSAWalker::verify(MSSA); + Walker.verify(MSSA); } }; -} -namespace llvm { /// \brief Rename a single basic block into MemorySSA form. /// Uses the standard SSA renaming algorithm. /// \returns The new incoming value. @@ -184,21 +1126,13 @@ MemoryAccess *MemorySSA::renameBlock(BasicBlock *BB, if (It != PerBlockAccesses.end()) { AccessList *Accesses = It->second.get(); for (MemoryAccess &L : *Accesses) { - switch (L.getValueID()) { - case Value::MemoryUseVal: - cast<MemoryUse>(&L)->setDefiningAccess(IncomingVal); - break; - case Value::MemoryDefVal: - // We can't legally optimize defs, because we only allow single - // memory phis/uses on operations, and if we optimize these, we can - // end up with multiple reaching defs. 
Uses do not have this - // problem, since they do not produce a value - cast<MemoryDef>(&L)->setDefiningAccess(IncomingVal); + if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&L)) { + if (MUD->getDefiningAccess() == nullptr) + MUD->setDefiningAccess(IncomingVal); + if (isa<MemoryDef>(&L)) + IncomingVal = &L; + } else { IncomingVal = &L; - break; - case Value::MemoryPhiVal: - IncomingVal = &L; - break; } } } @@ -295,21 +1229,10 @@ void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) { MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT) : AA(AA), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr), - NextID(0) { + NextID(INVALID_MEMORYACCESS_ID) { buildMemorySSA(); } -MemorySSA::MemorySSA(MemorySSA &&MSSA) - : AA(MSSA.AA), DT(MSSA.DT), F(MSSA.F), - ValueToMemoryAccess(std::move(MSSA.ValueToMemoryAccess)), - PerBlockAccesses(std::move(MSSA.PerBlockAccesses)), - LiveOnEntryDef(std::move(MSSA.LiveOnEntryDef)), - Walker(std::move(MSSA.Walker)), NextID(MSSA.NextID) { - // Update the Walker MSSA pointer so it doesn't point to the moved-from MSSA - // object any more. - Walker->MSSA = this; -} - MemorySSA::~MemorySSA() { // Drop all our references for (const auto &Pair : PerBlockAccesses) @@ -325,6 +1248,245 @@ MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) { return Res.first->second.get(); } +/// This class is a batch walker of all MemoryUse's in the program, and points +/// their defining access at the thing that actually clobbers them. Because it +/// is a batch walker that touches everything, it does not operate like the +/// other walkers. This walker is basically performing a top-down SSA renaming +/// pass, where the version stack is used as the cache. This enables it to be +/// significantly more time and memory efficient than using the regular walker, +/// which is walking bottom-up. +class MemorySSA::OptimizeUses { +public: + OptimizeUses(MemorySSA *MSSA, MemorySSAWalker *Walker, AliasAnalysis *AA, + DominatorTree *DT) + : MSSA(MSSA), Walker(Walker), AA(AA), DT(DT) { + Walker = MSSA->getWalker(); + } + + void optimizeUses(); + +private: + /// This represents where a given memorylocation is in the stack. + struct MemlocStackInfo { + // This essentially is keeping track of versions of the stack. Whenever + // the stack changes due to pushes or pops, these versions increase. + unsigned long StackEpoch; + unsigned long PopEpoch; + // This is the lower bound of places on the stack to check. It is equal to + // the place the last stack walk ended. + // Note: Correctness depends on this being initialized to 0, which densemap + // does + unsigned long LowerBound; + const BasicBlock *LowerBoundBlock; + // This is where the last walk for this memory location ended. + unsigned long LastKill; + bool LastKillValid; + }; + void optimizeUsesInBlock(const BasicBlock *, unsigned long &, unsigned long &, + SmallVectorImpl<MemoryAccess *> &, + DenseMap<MemoryLocOrCall, MemlocStackInfo> &); + MemorySSA *MSSA; + MemorySSAWalker *Walker; + AliasAnalysis *AA; + DominatorTree *DT; +}; + +/// Optimize the uses in a given block This is basically the SSA renaming +/// algorithm, with one caveat: We are able to use a single stack for all +/// MemoryUses. This is because the set of *possible* reaching MemoryDefs is +/// the same for every MemoryUse. The *actual* clobbering MemoryDef is just +/// going to be some position in that stack of possible ones. 
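// A minimal sketch of the single-version-stack idea described above, written
// against plain Instructions rather than MemoryAccesses (illustrative only,
// not code from this change): push every def while walking the dominator tree
// top-down, and resolve each load by scanning down from the top of the stack.
// The real code additionally pops entries whose blocks stop dominating the
// current block and adds the epoch/lower-bound bookkeeping below.
static void renameBlockSketch(AliasAnalysis &AA, BasicBlock *BB,
                              SmallVectorImpl<Instruction *> &VersionStack,
                              DenseMap<LoadInst *, Instruction *> &Clobbers) {
  for (Instruction &I : *BB) {
    if (I.mayWriteToMemory()) {
      VersionStack.push_back(&I); // every def is a new "version"
    } else if (auto *LI = dyn_cast<LoadInst>(&I)) {
      for (Instruction *Def : reverse(VersionStack))
        if (AA.getModRefInfo(Def, MemoryLocation::get(LI)) & MRI_Mod) {
          Clobbers[LI] = Def; // nearest dominating clobber wins
          break;
        }
    }
  }
}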
+/// +/// We track the stack positions that each MemoryLocation needs +/// to check, and last ended at. This is because we only want to check the +/// things that changed since last time. The same MemoryLocation should +/// get clobbered by the same store (getModRefInfo does not use invariantness or +/// things like this, and if they start, we can modify MemoryLocOrCall to +/// include relevant data) +void MemorySSA::OptimizeUses::optimizeUsesInBlock( + const BasicBlock *BB, unsigned long &StackEpoch, unsigned long &PopEpoch, + SmallVectorImpl<MemoryAccess *> &VersionStack, + DenseMap<MemoryLocOrCall, MemlocStackInfo> &LocStackInfo) { + + /// If no accesses, nothing to do. + MemorySSA::AccessList *Accesses = MSSA->getWritableBlockAccesses(BB); + if (Accesses == nullptr) + return; + + // Pop everything that doesn't dominate the current block off the stack, + // increment the PopEpoch to account for this. + while (!VersionStack.empty()) { + BasicBlock *BackBlock = VersionStack.back()->getBlock(); + if (DT->dominates(BackBlock, BB)) + break; + while (VersionStack.back()->getBlock() == BackBlock) + VersionStack.pop_back(); + ++PopEpoch; + } + for (MemoryAccess &MA : *Accesses) { + auto *MU = dyn_cast<MemoryUse>(&MA); + if (!MU) { + VersionStack.push_back(&MA); + ++StackEpoch; + continue; + } + + if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) { + MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true); + continue; + } + + MemoryLocOrCall UseMLOC(MU); + auto &LocInfo = LocStackInfo[UseMLOC]; + // If the pop epoch changed, it means we've removed stuff from top of + // stack due to changing blocks. We may have to reset the lower bound or + // last kill info. + if (LocInfo.PopEpoch != PopEpoch) { + LocInfo.PopEpoch = PopEpoch; + LocInfo.StackEpoch = StackEpoch; + // If the lower bound was in something that no longer dominates us, we + // have to reset it. + // We can't simply track stack size, because the stack may have had + // pushes/pops in the meantime. + // XXX: This is non-optimal, but only is slower cases with heavily + // branching dominator trees. To get the optimal number of queries would + // be to make lowerbound and lastkill a per-loc stack, and pop it until + // the top of that stack dominates us. This does not seem worth it ATM. + // A much cheaper optimization would be to always explore the deepest + // branch of the dominator tree first. This will guarantee this resets on + // the smallest set of blocks. + if (LocInfo.LowerBoundBlock && LocInfo.LowerBoundBlock != BB && + !DT->dominates(LocInfo.LowerBoundBlock, BB)) { + // Reset the lower bound of things to check. + // TODO: Some day we should be able to reset to last kill, rather than + // 0. + LocInfo.LowerBound = 0; + LocInfo.LowerBoundBlock = VersionStack[0]->getBlock(); + LocInfo.LastKillValid = false; + } + } else if (LocInfo.StackEpoch != StackEpoch) { + // If all that has changed is the StackEpoch, we only have to check the + // new things on the stack, because we've checked everything before. In + // this case, the lower bound of things to check remains the same. + LocInfo.PopEpoch = PopEpoch; + LocInfo.StackEpoch = StackEpoch; + } + if (!LocInfo.LastKillValid) { + LocInfo.LastKill = VersionStack.size() - 1; + LocInfo.LastKillValid = true; + } + + // At this point, we should have corrected last kill and LowerBound to be + // in bounds. 
+ assert(LocInfo.LowerBound < VersionStack.size() && + "Lower bound out of range"); + assert(LocInfo.LastKill < VersionStack.size() && + "Last kill info out of range"); + // In any case, the new upper bound is the top of the stack. + unsigned long UpperBound = VersionStack.size() - 1; + + if (UpperBound - LocInfo.LowerBound > MaxCheckLimit) { + DEBUG(dbgs() << "MemorySSA skipping optimization of " << *MU << " (" + << *(MU->getMemoryInst()) << ")" + << " because there are " << UpperBound - LocInfo.LowerBound + << " stores to disambiguate\n"); + // Because we did not walk, LastKill is no longer valid, as this may + // have been a kill. + LocInfo.LastKillValid = false; + continue; + } + bool FoundClobberResult = false; + while (UpperBound > LocInfo.LowerBound) { + if (isa<MemoryPhi>(VersionStack[UpperBound])) { + // For phis, use the walker, see where we ended up, go there + Instruction *UseInst = MU->getMemoryInst(); + MemoryAccess *Result = Walker->getClobberingMemoryAccess(UseInst); + // We are guaranteed to find it or something is wrong + while (VersionStack[UpperBound] != Result) { + assert(UpperBound != 0); + --UpperBound; + } + FoundClobberResult = true; + break; + } + + MemoryDef *MD = cast<MemoryDef>(VersionStack[UpperBound]); + // If the lifetime of the pointer ends at this instruction, it's live on + // entry. + if (!UseMLOC.IsCall && lifetimeEndsAt(MD, UseMLOC.getLoc(), *AA)) { + // Reset UpperBound to liveOnEntryDef's place in the stack + UpperBound = 0; + FoundClobberResult = true; + break; + } + if (instructionClobbersQuery(MD, MU, UseMLOC, *AA)) { + FoundClobberResult = true; + break; + } + --UpperBound; + } + // At the end of this loop, UpperBound is either a clobber, or lower bound + // PHI walking may cause it to be < LowerBound, and in fact, < LastKill. + if (FoundClobberResult || UpperBound < LocInfo.LastKill) { + MU->setDefiningAccess(VersionStack[UpperBound], true); + // We were last killed now by where we got to + LocInfo.LastKill = UpperBound; + } else { + // Otherwise, we checked all the new ones, and now we know we can get to + // LastKill. + MU->setDefiningAccess(VersionStack[LocInfo.LastKill], true); + } + LocInfo.LowerBound = VersionStack.size() - 1; + LocInfo.LowerBoundBlock = BB; + } +} + +/// Optimize uses to point to their actual clobbering definitions. +void MemorySSA::OptimizeUses::optimizeUses() { + + // We perform a non-recursive top-down dominator tree walk + struct StackInfo { + const DomTreeNode *Node; + DomTreeNode::const_iterator Iter; + }; + + SmallVector<MemoryAccess *, 16> VersionStack; + SmallVector<StackInfo, 16> DomTreeWorklist; + DenseMap<MemoryLocOrCall, MemlocStackInfo> LocStackInfo; + VersionStack.push_back(MSSA->getLiveOnEntryDef()); + + unsigned long StackEpoch = 1; + unsigned long PopEpoch = 1; + for (const auto *DomNode : depth_first(DT->getRootNode())) + optimizeUsesInBlock(DomNode->getBlock(), StackEpoch, PopEpoch, VersionStack, + LocStackInfo); +} + +void MemorySSA::placePHINodes( + const SmallPtrSetImpl<BasicBlock *> &DefiningBlocks, + const DenseMap<const BasicBlock *, unsigned int> &BBNumbers) { + // Determine where our MemoryPhi's should go + ForwardIDFCalculator IDFs(*DT); + IDFs.setDefiningBlocks(DefiningBlocks); + SmallVector<BasicBlock *, 32> IDFBlocks; + IDFs.calculate(IDFBlocks); + + std::sort(IDFBlocks.begin(), IDFBlocks.end(), + [&BBNumbers](const BasicBlock *A, const BasicBlock *B) { + return BBNumbers.lookup(A) < BBNumbers.lookup(B); + }); + + // Now place MemoryPhi nodes. 
+ for (auto &BB : IDFBlocks) { + // Insert phi node + AccessList *Accesses = getOrCreateAccessList(BB); + MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); + ValueToMemoryAccess[BB] = Phi; + // Phi's always are placed at the front of the block. + Accesses->push_front(Phi); + } +} + void MemorySSA::buildMemorySSA() { // We create an access to represent "live on entry", for things like // arguments or users of globals, where the memory they use is defined before @@ -335,6 +1497,8 @@ void MemorySSA::buildMemorySSA() { BasicBlock &StartingPoint = F.getEntryBlock(); LiveOnEntryDef = make_unique<MemoryDef>(F.getContext(), nullptr, nullptr, &StartingPoint, NextID++); + DenseMap<const BasicBlock *, unsigned int> BBNumbers; + unsigned NextBBNum = 0; // We maintain lists of memory accesses per-block, trading memory for time. We // could just look up the memory access for every possible instruction in the @@ -344,6 +1508,7 @@ void MemorySSA::buildMemorySSA() { // Go through each block, figure out where defs occur, and chain together all // the accesses. for (BasicBlock &B : F) { + BBNumbers[&B] = NextBBNum++; bool InsertIntoDef = false; AccessList *Accesses = nullptr; for (Instruction &I : B) { @@ -361,81 +1526,20 @@ void MemorySSA::buildMemorySSA() { if (Accesses) DefUseBlocks.insert(&B); } - - // Compute live-in. - // Live in is normally defined as "all the blocks on the path from each def to - // each of it's uses". - // MemoryDef's are implicit uses of previous state, so they are also uses. - // This means we don't really have def-only instructions. The only - // MemoryDef's that are not really uses are those that are of the LiveOnEntry - // variable (because LiveOnEntry can reach anywhere, and every def is a - // must-kill of LiveOnEntry). - // In theory, you could precisely compute live-in by using alias-analysis to - // disambiguate defs and uses to see which really pair up with which. - // In practice, this would be really expensive and difficult. So we simply - // assume all defs are also uses that need to be kept live. - // Because of this, the end result of this live-in computation will be "the - // entire set of basic blocks that reach any use". - - SmallPtrSet<BasicBlock *, 32> LiveInBlocks; - SmallVector<BasicBlock *, 64> LiveInBlockWorklist(DefUseBlocks.begin(), - DefUseBlocks.end()); - // Now that we have a set of blocks where a value is live-in, recursively add - // predecessors until we find the full region the value is live. - while (!LiveInBlockWorklist.empty()) { - BasicBlock *BB = LiveInBlockWorklist.pop_back_val(); - - // The block really is live in here, insert it into the set. If already in - // the set, then it has already been processed. - if (!LiveInBlocks.insert(BB).second) - continue; - - // Since the value is live into BB, it is either defined in a predecessor or - // live into it to. - LiveInBlockWorklist.append(pred_begin(BB), pred_end(BB)); - } - - // Determine where our MemoryPhi's should go - ForwardIDFCalculator IDFs(*DT); - IDFs.setDefiningBlocks(DefiningBlocks); - IDFs.setLiveInBlocks(LiveInBlocks); - SmallVector<BasicBlock *, 32> IDFBlocks; - IDFs.calculate(IDFBlocks); - - // Now place MemoryPhi nodes. - for (auto &BB : IDFBlocks) { - // Insert phi node - AccessList *Accesses = getOrCreateAccessList(BB); - MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); - ValueToMemoryAccess.insert(std::make_pair(BB, Phi)); - // Phi's always are placed at the front of the block. 
- Accesses->push_front(Phi); - } + placePHINodes(DefiningBlocks, BBNumbers); // Now do regular SSA renaming on the MemoryDef/MemoryUse. Visited will get // filled in with all blocks. SmallPtrSet<BasicBlock *, 16> Visited; renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited); - MemorySSAWalker *Walker = getWalker(); + CachingWalker *Walker = getWalkerImpl(); - // Now optimize the MemoryUse's defining access to point to the nearest - // dominating clobbering def. - // This ensures that MemoryUse's that are killed by the same store are - // immediate users of that store, one of the invariants we guarantee. - for (auto DomNode : depth_first(DT)) { - BasicBlock *BB = DomNode->getBlock(); - auto AI = PerBlockAccesses.find(BB); - if (AI == PerBlockAccesses.end()) - continue; - AccessList *Accesses = AI->second.get(); - for (auto &MA : *Accesses) { - if (auto *MU = dyn_cast<MemoryUse>(&MA)) { - Instruction *Inst = MU->getMemoryInst(); - MU->setDefiningAccess(Walker->getClobberingMemoryAccess(Inst)); - } - } - } + // We're doing a batch of updates; don't drop useful caches between them. + Walker->setAutoResetWalker(false); + OptimizeUses(this, Walker, AA, DT).optimizeUses(); + Walker->setAutoResetWalker(true); + Walker->resetClobberWalker(); // Mark the uses in unreachable blocks as live on entry, so that they go // somewhere. @@ -444,7 +1548,9 @@ void MemorySSA::buildMemorySSA() { markUnreachableAsLiveOnEntry(&BB); } -MemorySSAWalker *MemorySSA::getWalker() { +MemorySSAWalker *MemorySSA::getWalker() { return getWalkerImpl(); } + +MemorySSA::CachingWalker *MemorySSA::getWalkerImpl() { if (Walker) return Walker.get(); @@ -456,9 +1562,10 @@ MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) { assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB"); AccessList *Accesses = getOrCreateAccessList(BB); MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++); - ValueToMemoryAccess.insert(std::make_pair(BB, Phi)); + ValueToMemoryAccess[BB] = Phi; // Phi's always are placed at the front of the block. 
Accesses->push_front(Phi); + BlockNumberingValid.erase(BB); return Phi; } @@ -481,39 +1588,64 @@ MemoryAccess *MemorySSA::createMemoryAccessInBB(Instruction *I, auto *Accesses = getOrCreateAccessList(BB); if (Point == Beginning) { // It goes after any phi nodes - auto AI = std::find_if( - Accesses->begin(), Accesses->end(), - [](const MemoryAccess &MA) { return !isa<MemoryPhi>(MA); }); + auto AI = find_if( + *Accesses, [](const MemoryAccess &MA) { return !isa<MemoryPhi>(MA); }); Accesses->insert(AI, NewAccess); } else { Accesses->push_back(NewAccess); } - + BlockNumberingValid.erase(BB); return NewAccess; } -MemoryAccess *MemorySSA::createMemoryAccessBefore(Instruction *I, - MemoryAccess *Definition, - MemoryAccess *InsertPt) { + +MemoryUseOrDef *MemorySSA::createMemoryAccessBefore(Instruction *I, + MemoryAccess *Definition, + MemoryUseOrDef *InsertPt) { assert(I->getParent() == InsertPt->getBlock() && "New and old access must be in the same block"); MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); Accesses->insert(AccessList::iterator(InsertPt), NewAccess); + BlockNumberingValid.erase(InsertPt->getBlock()); return NewAccess; } -MemoryAccess *MemorySSA::createMemoryAccessAfter(Instruction *I, - MemoryAccess *Definition, - MemoryAccess *InsertPt) { +MemoryUseOrDef *MemorySSA::createMemoryAccessAfter(Instruction *I, + MemoryAccess *Definition, + MemoryAccess *InsertPt) { assert(I->getParent() == InsertPt->getBlock() && "New and old access must be in the same block"); MemoryUseOrDef *NewAccess = createDefinedAccess(I, Definition); auto *Accesses = getOrCreateAccessList(InsertPt->getBlock()); Accesses->insertAfter(AccessList::iterator(InsertPt), NewAccess); + BlockNumberingValid.erase(InsertPt->getBlock()); return NewAccess; } +void MemorySSA::spliceMemoryAccessAbove(MemoryDef *Where, + MemoryUseOrDef *What) { + assert(What != getLiveOnEntryDef() && + Where != getLiveOnEntryDef() && "Can't splice (above) LOE."); + assert(dominates(Where, What) && "Only upwards splices are permitted."); + + if (Where == What) + return; + if (isa<MemoryDef>(What)) { + // TODO: possibly use removeMemoryAccess' more efficient RAUW + What->replaceAllUsesWith(What->getDefiningAccess()); + What->setDefiningAccess(Where->getDefiningAccess()); + Where->setDefiningAccess(What); + } + AccessList *Src = getWritableBlockAccesses(What->getBlock()); + AccessList *Dest = getWritableBlockAccesses(Where->getBlock()); + Dest->splice(AccessList::iterator(Where), *Src, What); + + BlockNumberingValid.erase(What->getBlock()); + if (What->getBlock() != Where->getBlock()) + BlockNumberingValid.erase(Where->getBlock()); +} + /// \brief Helper function to create new memory accesses MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) { // The assume intrinsic has a control dependency which we model by claiming @@ -542,7 +1674,7 @@ MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) { MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++); else MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent()); - ValueToMemoryAccess.insert(std::make_pair(I, MUD)); + ValueToMemoryAccess[I] = MUD; return MUD; } @@ -611,6 +1743,7 @@ static MemoryAccess *onlySingleValue(MemoryPhi *MP) { void MemorySSA::removeFromLookups(MemoryAccess *MA) { assert(MA->use_empty() && "Trying to remove memory access that still has uses"); + BlockNumbering.erase(MA); if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(MA)) MUD->setDefiningAccess(nullptr); 
// Invalidate our walker's cache if necessary @@ -624,7 +1757,9 @@ void MemorySSA::removeFromLookups(MemoryAccess *MA) { } else { MemoryInst = MA->getBlock(); } - ValueToMemoryAccess.erase(MemoryInst); + auto VMA = ValueToMemoryAccess.find(MemoryInst); + if (VMA->second == MA) + ValueToMemoryAccess.erase(VMA); auto AccessIt = PerBlockAccesses.find(MA->getBlock()); std::unique_ptr<AccessList> &Accesses = AccessIt->second; @@ -652,8 +1787,27 @@ void MemorySSA::removeMemoryAccess(MemoryAccess *MA) { } // Re-point the uses at our defining access - if (!MA->use_empty()) - MA->replaceAllUsesWith(NewDefTarget); + if (!MA->use_empty()) { + // Reset optimized on users of this store, and reset the uses. + // A few notes: + // 1. This is a slightly modified version of RAUW to avoid walking the + // uses twice here. + // 2. If we wanted to be complete, we would have to reset the optimized + // flags on users of phi nodes if doing the below makes a phi node have all + // the same arguments. Instead, we prefer users to removeMemoryAccess those + // phi nodes, because doing it here would be N^3. + if (MA->hasValueHandle()) + ValueHandleBase::ValueIsRAUWd(MA, NewDefTarget); + // Note: We assume MemorySSA is not used in metadata since it's not really + // part of the IR. + + while (!MA->use_empty()) { + Use &U = *MA->use_begin(); + if (MemoryUse *MU = dyn_cast<MemoryUse>(U.getUser())) + MU->resetOptimized(); + U.set(NewDefTarget); + } + } // The call below to erase will destroy MA, so we can't change the order we // are doing things here @@ -674,6 +1828,7 @@ void MemorySSA::verifyMemorySSA() const { verifyDefUses(F); verifyDomination(F); verifyOrdering(F); + Walker->verify(this); } /// \brief Verify that the order and existence of MemoryAccesses matches the @@ -717,70 +1872,38 @@ void MemorySSA::verifyOrdering(Function &F) const { /// \brief Verify the domination properties of MemorySSA by checking that each /// definition dominates all of its uses. void MemorySSA::verifyDomination(Function &F) const { +#ifndef NDEBUG for (BasicBlock &B : F) { // Phi nodes are attached to basic blocks - if (MemoryPhi *MP = getMemoryAccess(&B)) { - for (User *U : MP->users()) { - BasicBlock *UseBlock; - // Phi operands are used on edges, we simulate the right domination by - // acting as if the use occurred at the end of the predecessor block. - if (MemoryPhi *P = dyn_cast<MemoryPhi>(U)) { - for (const auto &Arg : P->operands()) { - if (Arg == MP) { - UseBlock = P->getIncomingBlock(Arg); - break; - } - } - } else { - UseBlock = cast<MemoryAccess>(U)->getBlock(); - } - (void)UseBlock; - assert(DT->dominates(MP->getBlock(), UseBlock) && - "Memory PHI does not dominate it's uses"); - } - } + if (MemoryPhi *MP = getMemoryAccess(&B)) + for (const Use &U : MP->uses()) + assert(dominates(MP, U) && "Memory PHI does not dominate it's uses"); for (Instruction &I : B) { MemoryAccess *MD = dyn_cast_or_null<MemoryDef>(getMemoryAccess(&I)); if (!MD) continue; - for (User *U : MD->users()) { - BasicBlock *UseBlock; - (void)UseBlock; - // Things are allowed to flow to phi nodes over their predecessor edge. 
- if (auto *P = dyn_cast<MemoryPhi>(U)) { - for (const auto &Arg : P->operands()) { - if (Arg == MD) { - UseBlock = P->getIncomingBlock(Arg); - break; - } - } - } else { - UseBlock = cast<MemoryAccess>(U)->getBlock(); - } - assert(DT->dominates(MD->getBlock(), UseBlock) && - "Memory Def does not dominate it's uses"); - } + for (const Use &U : MD->uses()) + assert(dominates(MD, U) && "Memory Def does not dominate it's uses"); } } +#endif } /// \brief Verify the def-use lists in MemorySSA, by verifying that \p Use /// appears in the use list of \p Def. -/// -/// llvm_unreachable is used instead of asserts because this may be called in -/// a build without asserts. In that case, we don't want this to turn into a -/// nop. + void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const { +#ifndef NDEBUG // The live on entry use may cause us to get a NULL def here - if (!Def) { - if (!isLiveOnEntryDef(Use)) - llvm_unreachable("Null def but use not point to live on entry def"); - } else if (std::find(Def->user_begin(), Def->user_end(), Use) == - Def->user_end()) { - llvm_unreachable("Did not find use in def's use list"); - } + if (!Def) + assert(isLiveOnEntryDef(Use) && + "Null def but use not point to live on entry def"); + else + assert(is_contained(Def->users(), Use) && + "Did not find use in def's use list"); +#endif } /// \brief Verify the immediate use information, by walking all the memory @@ -798,21 +1921,35 @@ void MemorySSA::verifyDefUses(Function &F) const { } for (Instruction &I : B) { - if (MemoryAccess *MA = getMemoryAccess(&I)) { - assert(isa<MemoryUseOrDef>(MA) && - "Found a phi node not attached to a bb"); - verifyUseInDefs(cast<MemoryUseOrDef>(MA)->getDefiningAccess(), MA); + if (MemoryUseOrDef *MA = getMemoryAccess(&I)) { + verifyUseInDefs(MA->getDefiningAccess(), MA); } } } } -MemoryAccess *MemorySSA::getMemoryAccess(const Value *I) const { - return ValueToMemoryAccess.lookup(I); +MemoryUseOrDef *MemorySSA::getMemoryAccess(const Instruction *I) const { + return cast_or_null<MemoryUseOrDef>(ValueToMemoryAccess.lookup(I)); } MemoryPhi *MemorySSA::getMemoryAccess(const BasicBlock *BB) const { - return cast_or_null<MemoryPhi>(getMemoryAccess((const Value *)BB)); + return cast_or_null<MemoryPhi>(ValueToMemoryAccess.lookup(cast<Value>(BB))); +} + +/// Perform a local numbering on blocks so that instruction ordering can be +/// determined in constant time. +/// TODO: We currently just number in order. If we numbered by N, we could +/// allow at least N-1 sequences of insertBefore or insertAfter (and at least +/// log2(N) sequences of mixed before and after) without needing to invalidate +/// the numbering. +void MemorySSA::renumberBlock(const BasicBlock *B) const { + // The pre-increment ensures the numbers really start at 1. 
+ unsigned long CurrentNumber = 0; + const AccessList *AL = getBlockAccesses(B); + assert(AL != nullptr && "Asking to renumber an empty block"); + for (const auto &I : *AL) + BlockNumbering[&I] = ++CurrentNumber; + BlockNumberingValid.insert(B); } /// \brief Determine, for two memory accesses in the same block, @@ -821,9 +1958,10 @@ MemoryPhi *MemorySSA::getMemoryAccess(const BasicBlock *BB) const { bool MemorySSA::locallyDominates(const MemoryAccess *Dominator, const MemoryAccess *Dominatee) const { - assert((Dominator->getBlock() == Dominatee->getBlock()) && - "Asking for local domination when accesses are in different blocks!"); + const BasicBlock *DominatorBlock = Dominator->getBlock(); + assert((DominatorBlock == Dominatee->getBlock()) && + "Asking for local domination when accesses are in different blocks!"); // A node dominates itself. if (Dominatee == Dominator) return true; @@ -838,14 +1976,42 @@ bool MemorySSA::locallyDominates(const MemoryAccess *Dominator, if (isLiveOnEntryDef(Dominator)) return true; - // Get the access list for the block - const AccessList *AccessList = getBlockAccesses(Dominator->getBlock()); - AccessList::const_reverse_iterator It(Dominator->getIterator()); + if (!BlockNumberingValid.count(DominatorBlock)) + renumberBlock(DominatorBlock); + + unsigned long DominatorNum = BlockNumbering.lookup(Dominator); + // All numbers start with 1 + assert(DominatorNum != 0 && "Block was not numbered properly"); + unsigned long DominateeNum = BlockNumbering.lookup(Dominatee); + assert(DominateeNum != 0 && "Block was not numbered properly"); + return DominatorNum < DominateeNum; +} + +bool MemorySSA::dominates(const MemoryAccess *Dominator, + const MemoryAccess *Dominatee) const { + if (Dominator == Dominatee) + return true; + + if (isLiveOnEntryDef(Dominatee)) + return false; + + if (Dominator->getBlock() != Dominatee->getBlock()) + return DT->dominates(Dominator->getBlock(), Dominatee->getBlock()); + return locallyDominates(Dominator, Dominatee); +} - // If we hit the beginning of the access list before we hit dominatee, we must - // dominate it - return std::none_of(It, AccessList->rend(), - [&](const MemoryAccess &MA) { return &MA == Dominatee; }); +bool MemorySSA::dominates(const MemoryAccess *Dominator, + const Use &Dominatee) const { + if (MemoryPhi *MP = dyn_cast<MemoryPhi>(Dominatee.getUser())) { + BasicBlock *UseBB = MP->getIncomingBlock(Dominatee); + // The def must dominate the incoming block of the phi. + if (UseBB != Dominator->getBlock()) + return DT->dominates(Dominator->getBlock(), UseBB); + // If the UseBB and the DefBB are the same, compare locally. + return locallyDominates(Dominator, cast<MemoryAccess>(Dominatee)); + } + // If it's not a PHI node use, the normal dominates can already handle it. 
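// Assumed caller pattern for this Use-based dominates() overload (an
// illustrative sketch, not code from this change): when rewriting a MemoryPhi
// operand, dominance must be checked against the incoming edge rather than
// the phi's own block, which is exactly what the overload provides.
static bool canFeedPhiOperandSketch(const MemorySSA &MSSA, MemoryPhi *Phi,
                                    unsigned OpNo, MemoryAccess *NewDef) {
  return MSSA.dominates(NewDef, Phi->getOperandUse(OpNo));
}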
+ return dominates(Dominator, cast<MemoryAccess>(Dominatee.getUser())); } const static char LiveOnEntryStr[] = "liveOnEntry"; @@ -924,25 +2090,26 @@ bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) { return false; } -char MemorySSAAnalysis::PassID; +AnalysisKey MemorySSAAnalysis::Key; -MemorySSA MemorySSAAnalysis::run(Function &F, AnalysisManager<Function> &AM) { +MemorySSAAnalysis::Result MemorySSAAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); - return MemorySSA(F, &AA, &DT); + return MemorySSAAnalysis::Result(make_unique<MemorySSA>(F, &AA, &DT)); } PreservedAnalyses MemorySSAPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { OS << "MemorySSA for function: " << F.getName() << "\n"; - AM.getResult<MemorySSAAnalysis>(F).print(OS); + AM.getResult<MemorySSAAnalysis>(F).getMSSA().print(OS); return PreservedAnalyses::all(); } PreservedAnalyses MemorySSAVerifierPass::run(Function &F, FunctionAnalysisManager &AM) { - AM.getResult<MemorySSAAnalysis>(F).verifyMemorySSA(); + AM.getResult<MemorySSAAnalysis>(F).getMSSA().verifyMemorySSA(); return PreservedAnalyses::all(); } @@ -978,41 +2145,11 @@ MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {} MemorySSA::CachingWalker::CachingWalker(MemorySSA *M, AliasAnalysis *A, DominatorTree *D) - : MemorySSAWalker(M), AA(A), DT(D) {} + : MemorySSAWalker(M), Walker(*M, *A, *D, Cache), AutoResetWalker(true) {} MemorySSA::CachingWalker::~CachingWalker() {} -struct MemorySSA::CachingWalker::UpwardsMemoryQuery { - // True if we saw a phi whose predecessor was a backedge - bool SawBackedgePhi; - // True if our original query started off as a call - bool IsCall; - // The pointer location we started the query with. This will be empty if - // IsCall is true. - MemoryLocation StartingLoc; - // This is the instruction we were querying about. - const Instruction *Inst; - // Set of visited Instructions for this query. - DenseSet<MemoryAccessPair> Visited; - // Vector of visited call accesses for this query. This is separated out - // because you can always cache and lookup the result of call queries (IE when - // IsCall == true) for every call in the chain. The calls have no AA location - // associated with them with them, and thus, no context dependence. - SmallVector<const MemoryAccess *, 32> VisitedCalls; - // The MemoryAccess we actually got called with, used to test local domination - const MemoryAccess *OriginalAccess; - - UpwardsMemoryQuery() - : SawBackedgePhi(false), IsCall(false), Inst(nullptr), - OriginalAccess(nullptr) {} - - UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access) - : SawBackedgePhi(false), IsCall(ImmutableCallSite(Inst)), Inst(Inst), - OriginalAccess(Access) {} -}; - void MemorySSA::CachingWalker::invalidateInfo(MemoryAccess *MA) { - // TODO: We can do much better cache invalidation with differently stored // caches. For now, for MemoryUses, we simply remove them // from the cache, and kill the entire call/non-call cache for everything @@ -1026,220 +2163,38 @@ void MemorySSA::CachingWalker::invalidateInfo(MemoryAccess *MA) { // itself. 
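// Assumed caller pattern (an illustrative sketch, not code from this change):
// a pass that deletes a memory instruction removes its access first, which
// keeps MemorySSA consistent and ends up invalidating the walker's cached
// information for that access.
static void eraseWithMemorySSASketch(MemorySSA &MSSA, Instruction *DeadInst) {
  if (MemoryAccess *MA = MSSA.getMemoryAccess(DeadInst))
    MSSA.removeMemoryAccess(MA); // drops lookups and cached walker results
  DeadInst->eraseFromParent();
}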
if (MemoryUse *MU = dyn_cast<MemoryUse>(MA)) { - UpwardsMemoryQuery Q; - Instruction *I = MU->getMemoryInst(); - Q.IsCall = bool(ImmutableCallSite(I)); - Q.Inst = I; - if (!Q.IsCall) - Q.StartingLoc = MemoryLocation::get(I); - doCacheRemove(MA, Q, Q.StartingLoc); + UpwardsMemoryQuery Q(MU->getMemoryInst(), MU); + Cache.remove(MU, Q.StartingLoc, Q.IsCall); + MU->resetOptimized(); } else { // If it is not a use, the best we can do right now is destroy the cache. - CachedUpwardsClobberingCall.clear(); - CachedUpwardsClobberingAccess.clear(); + Cache.clear(); } #ifdef EXPENSIVE_CHECKS - // Run this only when expensive checks are enabled. verifyRemoved(MA); #endif } -void MemorySSA::CachingWalker::doCacheRemove(const MemoryAccess *M, - const UpwardsMemoryQuery &Q, - const MemoryLocation &Loc) { - if (Q.IsCall) - CachedUpwardsClobberingCall.erase(M); - else - CachedUpwardsClobberingAccess.erase({M, Loc}); -} - -void MemorySSA::CachingWalker::doCacheInsert(const MemoryAccess *M, - MemoryAccess *Result, - const UpwardsMemoryQuery &Q, - const MemoryLocation &Loc) { - // This is fine for Phis, since there are times where we can't optimize them. - // Making a def its own clobber is never correct, though. - assert((Result != M || isa<MemoryPhi>(M)) && - "Something can't clobber itself!"); - ++NumClobberCacheInserts; - if (Q.IsCall) - CachedUpwardsClobberingCall[M] = Result; - else - CachedUpwardsClobberingAccess[{M, Loc}] = Result; -} - -MemoryAccess * -MemorySSA::CachingWalker::doCacheLookup(const MemoryAccess *M, - const UpwardsMemoryQuery &Q, - const MemoryLocation &Loc) { - ++NumClobberCacheLookups; - MemoryAccess *Result; - - if (Q.IsCall) - Result = CachedUpwardsClobberingCall.lookup(M); - else - Result = CachedUpwardsClobberingAccess.lookup({M, Loc}); - - if (Result) - ++NumClobberCacheHits; - return Result; -} - -bool MemorySSA::CachingWalker::instructionClobbersQuery( - const MemoryDef *MD, UpwardsMemoryQuery &Q, - const MemoryLocation &Loc) const { - Instruction *DefMemoryInst = MD->getMemoryInst(); - assert(DefMemoryInst && "Defining instruction not actually an instruction"); - - if (!Q.IsCall) - return AA->getModRefInfo(DefMemoryInst, Loc) & MRI_Mod; - - // If this is a call, mark it for caching - if (ImmutableCallSite(DefMemoryInst)) - Q.VisitedCalls.push_back(MD); - ModRefInfo I = AA->getModRefInfo(DefMemoryInst, ImmutableCallSite(Q.Inst)); - return I != MRI_NoModRef; -} - -MemoryAccessPair MemorySSA::CachingWalker::UpwardsDFSWalk( - MemoryAccess *StartingAccess, const MemoryLocation &Loc, - UpwardsMemoryQuery &Q, bool FollowingBackedge) { - MemoryAccess *ModifyingAccess = nullptr; - - auto DFI = df_begin(StartingAccess); - for (auto DFE = df_end(StartingAccess); DFI != DFE;) { - MemoryAccess *CurrAccess = *DFI; - if (MSSA->isLiveOnEntryDef(CurrAccess)) - return {CurrAccess, Loc}; - // If this is a MemoryDef, check whether it clobbers our current query. This - // needs to be done before consulting the cache, because the cache reports - // the clobber for CurrAccess. If CurrAccess is a clobber for this query, - // and we ask the cache for information first, then we might skip this - // clobber, which is bad. - if (auto *MD = dyn_cast<MemoryDef>(CurrAccess)) { - // If we hit the top, stop following this path. - // While we can do lookups, we can't sanely do inserts here unless we were - // to track everything we saw along the way, since we don't know where we - // will stop. 
- if (instructionClobbersQuery(MD, Q, Loc)) { - ModifyingAccess = CurrAccess; - break; - } - } - if (auto CacheResult = doCacheLookup(CurrAccess, Q, Loc)) - return {CacheResult, Loc}; - - // We need to know whether it is a phi so we can track backedges. - // Otherwise, walk all upward defs. - if (!isa<MemoryPhi>(CurrAccess)) { - ++DFI; - continue; - } - -#ifndef NDEBUG - // The loop below visits the phi's children for us. Because phis are the - // only things with multiple edges, skipping the children should always lead - // us to the end of the loop. - // - // Use a copy of DFI because skipChildren would kill our search stack, which - // would make caching anything on the way back impossible. - auto DFICopy = DFI; - assert(DFICopy.skipChildren() == DFE && - "Skipping phi's children doesn't end the DFS?"); -#endif - - const MemoryAccessPair PHIPair(CurrAccess, Loc); - - // Don't try to optimize this phi again if we've already tried to do so. - if (!Q.Visited.insert(PHIPair).second) { - ModifyingAccess = CurrAccess; - break; - } - - std::size_t InitialVisitedCallSize = Q.VisitedCalls.size(); - - // Recurse on PHI nodes, since we need to change locations. - // TODO: Allow graphtraits on pairs, which would turn this whole function - // into a normal single depth first walk. - MemoryAccess *FirstDef = nullptr; - for (auto MPI = upward_defs_begin(PHIPair), MPE = upward_defs_end(); - MPI != MPE; ++MPI) { - bool Backedge = - !FollowingBackedge && - DT->dominates(CurrAccess->getBlock(), MPI.getPhiArgBlock()); - - MemoryAccessPair CurrentPair = - UpwardsDFSWalk(MPI->first, MPI->second, Q, Backedge); - // All the phi arguments should reach the same point if we can bypass - // this phi. The alternative is that they hit this phi node, which - // means we can skip this argument. - if (FirstDef && CurrentPair.first != PHIPair.first && - CurrentPair.first != FirstDef) { - ModifyingAccess = CurrAccess; - break; - } - - if (!FirstDef) - FirstDef = CurrentPair.first; - } - - // If we exited the loop early, go with the result it gave us. - if (!ModifyingAccess) { - assert(FirstDef && "Found a Phi with no upward defs?"); - ModifyingAccess = FirstDef; - } else { - // If we can't optimize this Phi, then we can't safely cache any of the - // calls we visited when trying to optimize it. Wipe them out now. - Q.VisitedCalls.resize(InitialVisitedCallSize); - } - break; - } - - if (!ModifyingAccess) - return {MSSA->getLiveOnEntryDef(), Q.StartingLoc}; - - const BasicBlock *OriginalBlock = StartingAccess->getBlock(); - assert(DFI.getPathLength() > 0 && "We dropped our path?"); - unsigned N = DFI.getPathLength(); - // If we found a clobbering def, the last element in the path will be our - // clobber, so we don't want to cache that to itself. OTOH, if we optimized a - // phi, we can add the last thing in the path to the cache, since that won't - // be the result. - if (DFI.getPath(N - 1) == ModifyingAccess) - --N; - for (; N > 1; --N) { - MemoryAccess *CacheAccess = DFI.getPath(N - 1); - BasicBlock *CurrBlock = CacheAccess->getBlock(); - if (!FollowingBackedge) - doCacheInsert(CacheAccess, ModifyingAccess, Q, Loc); - if (DT->dominates(CurrBlock, OriginalBlock) && - (CurrBlock != OriginalBlock || !FollowingBackedge || - MSSA->locallyDominates(CacheAccess, StartingAccess))) - break; - } - - // Cache everything else on the way back. The caller should cache - // StartingAccess for us. 
- for (; N > 1; --N) { - MemoryAccess *CacheAccess = DFI.getPath(N - 1); - doCacheInsert(CacheAccess, ModifyingAccess, Q, Loc); - } - - return {ModifyingAccess, Loc}; -} - /// \brief Walk the use-def chains starting at \p MA and find /// the MemoryAccess that actually clobbers Loc. /// /// \returns our clobbering memory access MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( MemoryAccess *StartingAccess, UpwardsMemoryQuery &Q) { - return UpwardsDFSWalk(StartingAccess, Q.StartingLoc, Q, false).first; + MemoryAccess *New = Walker.findClobber(StartingAccess, Q); +#ifdef EXPENSIVE_CHECKS + MemoryAccess *NewNoCache = + Walker.findClobber(StartingAccess, Q, /*UseWalkerCache=*/false); + assert(NewNoCache == New && "Cache made us hand back a different result?"); +#endif + if (AutoResetWalker) + resetClobberWalker(); + return New; } MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( - MemoryAccess *StartingAccess, MemoryLocation &Loc) { + MemoryAccess *StartingAccess, const MemoryLocation &Loc) { if (isa<MemoryPhi>(StartingAccess)) return StartingAccess; @@ -1257,10 +2212,10 @@ MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( UpwardsMemoryQuery Q; Q.OriginalAccess = StartingUseOrDef; Q.StartingLoc = Loc; - Q.Inst = StartingUseOrDef->getMemoryInst(); + Q.Inst = I; Q.IsCall = false; - if (auto CacheResult = doCacheLookup(StartingUseOrDef, Q, Q.StartingLoc)) + if (auto *CacheResult = Cache.lookup(StartingUseOrDef, Loc, Q.IsCall)) return CacheResult; // Unlike the other function, do not walk to the def of a def, because we are @@ -1270,9 +2225,6 @@ MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( : StartingUseOrDef; MemoryAccess *Clobber = getClobberingMemoryAccess(DefiningAccess, Q); - // Only cache this if it wouldn't make Clobber point to itself. - if (Clobber != StartingAccess) - doCacheInsert(Q.OriginalAccess, Clobber, Q, Q.StartingLoc); DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); DEBUG(dbgs() << *StartingUseOrDef << "\n"); DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); @@ -1281,28 +2233,38 @@ MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( } MemoryAccess * -MemorySSA::CachingWalker::getClobberingMemoryAccess(const Instruction *I) { - // There should be no way to lookup an instruction and get a phi as the - // access, since we only map BB's to PHI's. So, this must be a use or def. - auto *StartingAccess = cast<MemoryUseOrDef>(MSSA->getMemoryAccess(I)); - - bool IsCall = bool(ImmutableCallSite(I)); - +MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA) { + auto *StartingAccess = dyn_cast<MemoryUseOrDef>(MA); + // If this is a MemoryPhi, we can't do anything. + if (!StartingAccess) + return MA; + + // If this is an already optimized use or def, return the optimized result. + // Note: Currently, we do not store the optimized def result because we'd need + // a separate field, since we can't use it as the defining access. + if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess)) + if (MU->isOptimized()) + return MU->getDefiningAccess(); + + const Instruction *I = StartingAccess->getMemoryInst(); + UpwardsMemoryQuery Q(I, StartingAccess); // We can't sanely do anything with a fences, they conservatively // clobber all memory, and have no locations to get pointers from to // try to disambiguate. 
- if (!IsCall && I->isFenceLike()) + if (!Q.IsCall && I->isFenceLike()) return StartingAccess; - UpwardsMemoryQuery Q; - Q.OriginalAccess = StartingAccess; - Q.IsCall = IsCall; - if (!Q.IsCall) - Q.StartingLoc = MemoryLocation::get(I); - Q.Inst = I; - if (auto CacheResult = doCacheLookup(StartingAccess, Q, Q.StartingLoc)) + if (auto *CacheResult = Cache.lookup(StartingAccess, Q.StartingLoc, Q.IsCall)) return CacheResult; + if (isUseTriviallyOptimizableToLiveOnEntry(*MSSA->AA, I)) { + MemoryAccess *LiveOnEntry = MSSA->getLiveOnEntryDef(); + Cache.insert(StartingAccess, LiveOnEntry, Q.StartingLoc, Q.IsCall); + if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess)) + MU->setDefiningAccess(LiveOnEntry, true); + return LiveOnEntry; + } + // Start with the thing we already think clobbers this location MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess(); @@ -1312,50 +2274,32 @@ MemorySSA::CachingWalker::getClobberingMemoryAccess(const Instruction *I) { return DefiningAccess; MemoryAccess *Result = getClobberingMemoryAccess(DefiningAccess, Q); - // DFS won't cache a result for DefiningAccess. So, if DefiningAccess isn't - // our clobber, be sure that it gets a cache entry, too. - if (Result != DefiningAccess) - doCacheInsert(DefiningAccess, Result, Q, Q.StartingLoc); - doCacheInsert(Q.OriginalAccess, Result, Q, Q.StartingLoc); - // TODO: When this implementation is more mature, we may want to figure out - // what this additional caching buys us. It's most likely A Good Thing. - if (Q.IsCall) - for (const MemoryAccess *MA : Q.VisitedCalls) - if (MA != Result) - doCacheInsert(MA, Result, Q, Q.StartingLoc); - DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); DEBUG(dbgs() << *DefiningAccess << "\n"); DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); DEBUG(dbgs() << *Result << "\n"); + if (MemoryUse *MU = dyn_cast<MemoryUse>(StartingAccess)) + MU->setDefiningAccess(Result, true); return Result; } // Verify that MA doesn't exist in any of the caches. 
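// Assumed client usage of the walker API above (an illustrative sketch, not
// code from this change): ask for the real clobber of a load and see whether
// it is a store that later reasoning (e.g. must-alias checks) can exploit.
static StoreInst *findCandidateClobberingStoreSketch(MemorySSA &MSSA,
                                                     LoadInst *LI) {
  MemoryAccess *Clobber =
      MSSA.getWalker()->getClobberingMemoryAccess(MSSA.getMemoryAccess(LI));
  if (MSSA.isLiveOnEntryDef(Clobber))
    return nullptr; // nothing in the function clobbers the load
  if (auto *Def = dyn_cast<MemoryDef>(Clobber))
    return dyn_cast<StoreInst>(Def->getMemoryInst());
  return nullptr; // the clobber is a MemoryPhi
}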
void MemorySSA::CachingWalker::verifyRemoved(MemoryAccess *MA) { -#ifndef NDEBUG - for (auto &P : CachedUpwardsClobberingAccess) - assert(P.first.first != MA && P.second != MA && - "Found removed MemoryAccess in cache."); - for (auto &P : CachedUpwardsClobberingCall) - assert(P.first != MA && P.second != MA && - "Found removed MemoryAccess in cache."); -#endif // !NDEBUG + assert(!Cache.contains(MA) && "Found removed MemoryAccess in cache."); } MemoryAccess * -DoNothingMemorySSAWalker::getClobberingMemoryAccess(const Instruction *I) { - MemoryAccess *MA = MSSA->getMemoryAccess(I); +DoNothingMemorySSAWalker::getClobberingMemoryAccess(MemoryAccess *MA) { if (auto *Use = dyn_cast<MemoryUseOrDef>(MA)) return Use->getDefiningAccess(); return MA; } MemoryAccess *DoNothingMemorySSAWalker::getClobberingMemoryAccess( - MemoryAccess *StartingAccess, MemoryLocation &) { + MemoryAccess *StartingAccess, const MemoryLocation &) { if (auto *Use = dyn_cast<MemoryUseOrDef>(StartingAccess)) return Use->getDefiningAccess(); return StartingAccess; } -} +} // namespace llvm diff --git a/contrib/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/contrib/llvm/lib/Transforms/Utils/ModuleUtils.cpp index eb91885..0d623df 100644 --- a/contrib/llvm/lib/Transforms/Utils/ModuleUtils.cpp +++ b/contrib/llvm/lib/Transforms/Utils/ModuleUtils.cpp @@ -89,6 +89,44 @@ void llvm::appendToGlobalDtors(Module &M, Function *F, int Priority, Constant *D appendToGlobalArray("llvm.global_dtors", M, F, Priority, Data); } +static void appendToUsedList(Module &M, StringRef Name, ArrayRef<GlobalValue *> Values) { + GlobalVariable *GV = M.getGlobalVariable(Name); + SmallPtrSet<Constant *, 16> InitAsSet; + SmallVector<Constant *, 16> Init; + if (GV) { + ConstantArray *CA = dyn_cast<ConstantArray>(GV->getInitializer()); + for (auto &Op : CA->operands()) { + Constant *C = cast_or_null<Constant>(Op); + if (InitAsSet.insert(C).second) + Init.push_back(C); + } + GV->eraseFromParent(); + } + + Type *Int8PtrTy = llvm::Type::getInt8PtrTy(M.getContext()); + for (auto *V : Values) { + Constant *C = ConstantExpr::getBitCast(V, Int8PtrTy); + if (InitAsSet.insert(C).second) + Init.push_back(C); + } + + if (Init.empty()) + return; + + ArrayType *ATy = ArrayType::get(Int8PtrTy, Init.size()); + GV = new llvm::GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage, + ConstantArray::get(ATy, Init), Name); + GV->setSection("llvm.metadata"); +} + +void llvm::appendToUsed(Module &M, ArrayRef<GlobalValue *> Values) { + appendToUsedList(M, "llvm.used", Values); +} + +void llvm::appendToCompilerUsed(Module &M, ArrayRef<GlobalValue *> Values) { + appendToUsedList(M, "llvm.compiler.used", Values); +} + Function *llvm::checkSanitizerInterfaceFunction(Constant *FuncOrBitcast) { if (isa<Function>(FuncOrBitcast)) return cast<Function>(FuncOrBitcast); @@ -104,7 +142,7 @@ std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions( ArrayRef<Type *> InitArgTypes, ArrayRef<Value *> InitArgs, StringRef VersionCheckName) { assert(!InitName.empty() && "Expected init function name"); - assert(InitArgTypes.size() == InitArgTypes.size() && + assert(InitArgs.size() == InitArgTypes.size() && "Sanitizer's init function expects different number of arguments"); Function *Ctor = Function::Create( FunctionType::get(Type::getVoidTy(M.getContext()), false), @@ -126,3 +164,67 @@ std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions( } return std::make_pair(Ctor, InitFunction); } + +void llvm::filterDeadComdatFunctions( + Module &M, 
SmallVectorImpl<Function *> &DeadComdatFunctions) { + // Build a map from the comdat to the number of entries in that comdat we + // think are dead. If this fully covers the comdat group, then the entire + // group is dead. If we find another entry in the comdat group though, we'll + // have to preserve the whole group. + SmallDenseMap<Comdat *, int, 16> ComdatEntriesCovered; + for (Function *F : DeadComdatFunctions) { + Comdat *C = F->getComdat(); + assert(C && "Expected all input GVs to be in a comdat!"); + ComdatEntriesCovered[C] += 1; + } + + auto CheckComdat = [&](Comdat &C) { + auto CI = ComdatEntriesCovered.find(&C); + if (CI == ComdatEntriesCovered.end()) + return; + + // If this could have been covered by a dead entry, just subtract one to + // account for it. + if (CI->second > 0) { + CI->second -= 1; + return; + } + + // If we've already accounted for all the entries that were dead, the + // entire comdat is alive so remove it from the map. + ComdatEntriesCovered.erase(CI); + }; + + auto CheckAllComdats = [&] { + for (Function &F : M.functions()) + if (Comdat *C = F.getComdat()) { + CheckComdat(*C); + if (ComdatEntriesCovered.empty()) + return; + } + for (GlobalVariable &GV : M.globals()) + if (Comdat *C = GV.getComdat()) { + CheckComdat(*C); + if (ComdatEntriesCovered.empty()) + return; + } + for (GlobalAlias &GA : M.aliases()) + if (Comdat *C = GA.getComdat()) { + CheckComdat(*C); + if (ComdatEntriesCovered.empty()) + return; + } + }; + CheckAllComdats(); + + if (ComdatEntriesCovered.empty()) { + DeadComdatFunctions.clear(); + return; + } + + // Remove the entries that were not covering. + erase_if(DeadComdatFunctions, [&](GlobalValue *GV) { + return ComdatEntriesCovered.find(GV->getComdat()) == + ComdatEntriesCovered.end(); + }); +} diff --git a/contrib/llvm/lib/Transforms/Utils/NameAnonFunctions.cpp b/contrib/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp index c4f3839..34dc1cc 100644 --- a/contrib/llvm/lib/Transforms/Utils/NameAnonFunctions.cpp +++ b/contrib/llvm/lib/Transforms/Utils/NameAnonGlobals.cpp @@ -1,4 +1,4 @@ -//===- NameAnonFunctions.cpp - ThinLTO Summary-based Function Import ------===// +//===- NameAnonGlobals.cpp - ThinLTO Support: Name Unnamed Globals --------===// // // The LLVM Compiler Infrastructure // @@ -7,11 +7,13 @@ // //===----------------------------------------------------------------------===// // -// This file implements naming anonymous function to make sure they can be -// refered to by ThinLTO. +// This file implements naming anonymous globals to make sure they can be +// referred to by ThinLTO. // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils/NameAnonGlobals.h" + #include "llvm/ADT/SmallString.h" #include "llvm/IR/Module.h" #include "llvm/Support/MD5.h" @@ -19,8 +21,9 @@ using namespace llvm; +namespace { // Compute a "unique" hash for the module based on the name of the public -// functions. +// globals. class ModuleHasher { Module &TheModule; std::string TheHash; @@ -57,46 +60,62 @@ public: return TheHash; } }; +} // end anonymous namespace -// Rename all the anon functions in the module -bool llvm::nameUnamedFunctions(Module &M) { +// Rename all the anon globals in the module +bool llvm::nameUnamedGlobals(Module &M) { bool Changed = false; ModuleHasher ModuleHash(M); int count = 0; - for (auto &F : M) { - if (F.hasName()) - continue; - F.setName(Twine("anon.") + ModuleHash.get() + "." 
+ Twine(count++)); + auto RenameIfNeed = [&](GlobalValue &GV) { + if (GV.hasName()) + return; + GV.setName(Twine("anon.") + ModuleHash.get() + "." + Twine(count++)); Changed = true; - } + }; + for (auto &GO : M.global_objects()) + RenameIfNeed(GO); + for (auto &GA : M.aliases()) + RenameIfNeed(GA); + return Changed; } namespace { -// Simple pass that provides a name to every anon function. -class NameAnonFunction : public ModulePass { +// Legacy pass that provides a name to every anon globals. +class NameAnonGlobalLegacyPass : public ModulePass { public: /// Pass identification, replacement for typeid static char ID; /// Specify pass name for debug output - const char *getPassName() const override { return "Name Anon Functions"; } + StringRef getPassName() const override { return "Name Anon Globals"; } - explicit NameAnonFunction() : ModulePass(ID) {} + explicit NameAnonGlobalLegacyPass() : ModulePass(ID) {} - bool runOnModule(Module &M) override { return nameUnamedFunctions(M); } + bool runOnModule(Module &M) override { return nameUnamedGlobals(M); } }; -char NameAnonFunction::ID = 0; +char NameAnonGlobalLegacyPass::ID = 0; } // anonymous namespace -INITIALIZE_PASS_BEGIN(NameAnonFunction, "name-anon-functions", - "Provide a name to nameless functions", false, false) -INITIALIZE_PASS_END(NameAnonFunction, "name-anon-functions", - "Provide a name to nameless functions", false, false) +PreservedAnalyses NameAnonGlobalPass::run(Module &M, + ModuleAnalysisManager &AM) { + if (!nameUnamedGlobals(M)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +INITIALIZE_PASS_BEGIN(NameAnonGlobalLegacyPass, "name-anon-globals", + "Provide a name to nameless globals", false, false) +INITIALIZE_PASS_END(NameAnonGlobalLegacyPass, "name-anon-globals", + "Provide a name to nameless globals", false, false) namespace llvm { -ModulePass *createNameAnonFunctionPass() { return new NameAnonFunction(); } +ModulePass *createNameAnonGlobalPass() { + return new NameAnonGlobalLegacyPass(); +} } diff --git a/contrib/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/contrib/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index cbf385d..35faa6f 100644 --- a/contrib/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/contrib/llvm/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -907,6 +907,8 @@ NextIteration: // The currently active variable for this block is now the PHI. IncomingVals[AllocaNo] = APN; + if (DbgDeclareInst *DDI = AllocaDbgDeclares[AllocaNo]) + ConvertDebugDeclareToDebugValue(DDI, APN, DIB); // Get the next phi node. 
++PNI; diff --git a/contrib/llvm/lib/Transforms/Utils/SSAUpdater.cpp b/contrib/llvm/lib/Transforms/Utils/SSAUpdater.cpp index 88b39dd..8e93ee7 100644 --- a/contrib/llvm/lib/Transforms/Utils/SSAUpdater.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SSAUpdater.cpp @@ -482,5 +482,5 @@ bool LoadAndStorePromoter::isInstInList(Instruction *I, const SmallVectorImpl<Instruction*> &Insts) const { - return std::find(Insts.begin(), Insts.end(), I) != Insts.end(); + return is_contained(Insts, I); } diff --git a/contrib/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/contrib/llvm/lib/Transforms/Utils/SimplifyCFG.cpp index c197317..7b0bddb 100644 --- a/contrib/llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -11,27 +11,39 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" @@ -40,15 +52,29 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> +#include <cassert> +#include <climits> +#include <cstddef> +#include <cstdint> +#include <iterator> #include <map> #include <set> +#include <utility> +#include <vector> + using namespace llvm; using namespace PatternMatch; @@ -110,6 +136,7 @@ STATISTIC(NumSinkCommons, STATISTIC(NumSpeculations, "Number of speculative executed instructions"); namespace { + // The first field contains the value that the switch produces when a certain // case group is selected, and the second field is a vector containing the // cases composing the case group. 
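An aside for readers of the switch-simplification changes that follow: the comment above describes a (result value, case group) pairing. A minimal, hypothetical source-level sketch of that pairing (names and constants invented, not part of the patch):

int classify(int x) {
  switch (x) {
  case 0:
  case 4:
    return 10;      // case group {0, 4} produces 10
  case 7:
    return 20;      // case group {7} produces 20
  default:
    return -1;
  }
}
// SimplifyCFG records this roughly as [(10, {0, 4}), (20, {7})] plus a default
// result, and later decides whether a select or a lookup table can replace the
// switch.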
@@ -168,13 +195,17 @@ public: SmallPtrSetImpl<BasicBlock *> *LoopHeaders) : TTI(TTI), DL(DL), BonusInstThreshold(BonusInstThreshold), AC(AC), LoopHeaders(LoopHeaders) {} + bool run(BasicBlock *BB); }; -} + +} // end anonymous namespace /// Return true if it is safe to merge these two /// terminator instructions together. -static bool SafeToMergeTerminators(TerminatorInst *SI1, TerminatorInst *SI2) { +static bool +SafeToMergeTerminators(TerminatorInst *SI1, TerminatorInst *SI2, + SmallSetVector<BasicBlock *, 4> *FailBlocks = nullptr) { if (SI1 == SI2) return false; // Can't merge with self! @@ -183,18 +214,22 @@ static bool SafeToMergeTerminators(TerminatorInst *SI1, TerminatorInst *SI2) { // conflicting incoming values from the two switch blocks. BasicBlock *SI1BB = SI1->getParent(); BasicBlock *SI2BB = SI2->getParent(); - SmallPtrSet<BasicBlock *, 16> SI1Succs(succ_begin(SI1BB), succ_end(SI1BB)); + SmallPtrSet<BasicBlock *, 16> SI1Succs(succ_begin(SI1BB), succ_end(SI1BB)); + bool Fail = false; for (BasicBlock *Succ : successors(SI2BB)) if (SI1Succs.count(Succ)) for (BasicBlock::iterator BBI = Succ->begin(); isa<PHINode>(BBI); ++BBI) { PHINode *PN = cast<PHINode>(BBI); if (PN->getIncomingValueForBlock(SI1BB) != - PN->getIncomingValueForBlock(SI2BB)) - return false; + PN->getIncomingValueForBlock(SI2BB)) { + if (FailBlocks) + FailBlocks->insert(Succ); + Fail = true; + } } - return true; + return !Fail; } /// Return true if it is safe and profitable to merge these two terminator @@ -621,7 +656,8 @@ private: } } }; -} + +} // end anonymous namespace static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) { Instruction *Cond = nullptr; @@ -706,7 +742,7 @@ static bool ValuesOverlap(std::vector<ValueEqualityComparisonCase> &C1, if (V1->size() > V2->size()) std::swap(V1, V2); - if (V1->size() == 0) + if (V1->empty()) return false; if (V1->size() == 1) { // Just scan V2. @@ -874,6 +910,7 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( } namespace { + /// This class implements a stable ordering of constant /// integers that does not depend on their address. This is important for /// applications that sort ConstantInt's to ensure uniqueness. @@ -882,7 +919,8 @@ struct ConstantIntOrdering { return LHS->getValue().ult(RHS->getValue()); } }; -} + +} // end anonymous namespace static int ConstantIntSortPredicate(ConstantInt *const *P1, ConstantInt *const *P2) { @@ -954,7 +992,16 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, TerminatorInst *PTI = Pred->getTerminator(); Value *PCV = isValueEqualityComparison(PTI); // PredCondVal - if (PCV == CV && SafeToMergeTerminators(TI, PTI)) { + if (PCV == CV && TI != PTI) { + SmallSetVector<BasicBlock*, 4> FailBlocks; + if (!SafeToMergeTerminators(TI, PTI, &FailBlocks)) { + for (auto *Succ : FailBlocks) { + std::vector<BasicBlock*> Blocks = { TI->getParent() }; + if (!SplitBlockPredecessors(Succ, Blocks, ".fold.split")) + return false; + } + } + // Figure out which 'cases' to copy from SI to PSI. 
std::vector<ValueEqualityComparisonCase> BBCases; BasicBlock *BBDefault = GetValueEqualityComparisonCases(TI, BBCases); @@ -1215,7 +1262,7 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, BIParent->getInstList().splice(BI->getIterator(), BB1->getInstList(), I1); if (!I2->use_empty()) I2->replaceAllUsesWith(I1); - I1->intersectOptionalDataWith(I2); + I1->andIRFlags(I2); unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_range, LLVMContext::MD_fpmath, @@ -1227,6 +1274,13 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, LLVMContext::MD_dereferenceable_or_null, LLVMContext::MD_mem_parallel_loop_access}; combineMetadata(I1, I2, KnownIDs); + + // I1 and I2 are being combined into a single instruction. Its debug + // location is the merged locations of the original instructions. + if (!isa<CallInst>(I1)) + I1->setDebugLoc( + DILocation::getMergedLocation(I1->getDebugLoc(), I2->getDebugLoc())); + I2->eraseFromParent(); Changed = true; @@ -1319,172 +1373,462 @@ HoistTerminator: return true; } -/// Given an unconditional branch that goes to BBEnd, -/// check whether BBEnd has only two predecessors and the other predecessor -/// ends with an unconditional branch. If it is true, sink any common code -/// in the two predecessors to BBEnd. -static bool SinkThenElseCodeToEnd(BranchInst *BI1) { - assert(BI1->isUnconditional()); - BasicBlock *BB1 = BI1->getParent(); - BasicBlock *BBEnd = BI1->getSuccessor(0); - - // Check that BBEnd has two predecessors and the other predecessor ends with - // an unconditional branch. - pred_iterator PI = pred_begin(BBEnd), PE = pred_end(BBEnd); - BasicBlock *Pred0 = *PI++; - if (PI == PE) // Only one predecessor. - return false; - BasicBlock *Pred1 = *PI++; - if (PI != PE) // More than two predecessors. +// Is it legal to place a variable in operand \c OpIdx of \c I? +// FIXME: This should be promoted to Instruction. +static bool canReplaceOperandWithVariable(const Instruction *I, + unsigned OpIdx) { + // We can't have a PHI with a metadata type. + if (I->getOperand(OpIdx)->getType()->isMetadataTy()) return false; - BasicBlock *BB2 = (Pred0 == BB1) ? Pred1 : Pred0; - BranchInst *BI2 = dyn_cast<BranchInst>(BB2->getTerminator()); - if (!BI2 || !BI2->isUnconditional()) + + // Early exit. + if (!isa<Constant>(I->getOperand(OpIdx))) + return true; + + switch (I->getOpcode()) { + default: + return true; + case Instruction::Call: + case Instruction::Invoke: + // FIXME: many arithmetic intrinsics have no issue taking a + // variable, however it's hard to distingish these from + // specials such as @llvm.frameaddress that require a constant. + if (isa<IntrinsicInst>(I)) + return false; + + // Constant bundle operands may need to retain their constant-ness for + // correctness. + if (ImmutableCallSite(I).isBundleOperand(OpIdx)) + return false; + + return true; + + case Instruction::ShuffleVector: + // Shufflevector masks are constant. + return OpIdx != 2; + case Instruction::ExtractValue: + case Instruction::InsertValue: + // All operands apart from the first are constant. + return OpIdx == 0; + case Instruction::Alloca: return false; + case Instruction::GetElementPtr: + if (OpIdx == 0) + return true; + gep_type_iterator It = std::next(gep_type_begin(I), OpIdx - 1); + return It.isSequential(); + } +} - // Gather the PHI nodes in BBEnd. 
- SmallDenseMap<std::pair<Value *, Value *>, PHINode *> JointValueMap; - Instruction *FirstNonPhiInBBEnd = nullptr; - for (BasicBlock::iterator I = BBEnd->begin(), E = BBEnd->end(); I != E; ++I) { - if (PHINode *PN = dyn_cast<PHINode>(I)) { - Value *BB1V = PN->getIncomingValueForBlock(BB1); - Value *BB2V = PN->getIncomingValueForBlock(BB2); - JointValueMap[std::make_pair(BB1V, BB2V)] = PN; - } else { - FirstNonPhiInBBEnd = &*I; - break; - } +// All instructions in Insts belong to different blocks that all unconditionally +// branch to a common successor. Analyze each instruction and return true if it +// would be possible to sink them into their successor, creating one common +// instruction instead. For every value that would be required to be provided by +// PHI node (because an operand varies in each input block), add to PHIOperands. +static bool canSinkInstructions( + ArrayRef<Instruction *> Insts, + DenseMap<Instruction *, SmallVector<Value *, 4>> &PHIOperands) { + // Prune out obviously bad instructions to move. Any non-store instruction + // must have exactly one use, and we check later that use is by a single, + // common PHI instruction in the successor. + for (auto *I : Insts) { + // These instructions may change or break semantics if moved. + if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) || + I->getType()->isTokenTy()) + return false; + + // Conservatively return false if I is an inline-asm instruction. Sinking + // and merging inline-asm instructions can potentially create arguments + // that cannot satisfy the inline-asm constraints. + if (const auto *C = dyn_cast<CallInst>(I)) + if (C->isInlineAsm()) + return false; + + // Everything must have only one use too, apart from stores which + // have no uses. + if (!isa<StoreInst>(I) && !I->hasOneUse()) + return false; } - if (!FirstNonPhiInBBEnd) - return false; - // This does very trivial matching, with limited scanning, to find identical - // instructions in the two blocks. We scan backward for obviously identical - // instructions in an identical order. - BasicBlock::InstListType::reverse_iterator RI1 = BB1->getInstList().rbegin(), - RE1 = BB1->getInstList().rend(), - RI2 = BB2->getInstList().rbegin(), - RE2 = BB2->getInstList().rend(); - // Skip debug info. - while (RI1 != RE1 && isa<DbgInfoIntrinsic>(&*RI1)) - ++RI1; - if (RI1 == RE1) - return false; - while (RI2 != RE2 && isa<DbgInfoIntrinsic>(&*RI2)) - ++RI2; - if (RI2 == RE2) - return false; - // Skip the unconditional branches. - ++RI1; - ++RI2; + const Instruction *I0 = Insts.front(); + for (auto *I : Insts) + if (!I->isSameOperationAs(I0)) + return false; - bool Changed = false; - while (RI1 != RE1 && RI2 != RE2) { - // Skip debug info. - while (RI1 != RE1 && isa<DbgInfoIntrinsic>(&*RI1)) - ++RI1; - if (RI1 == RE1) - return Changed; - while (RI2 != RE2 && isa<DbgInfoIntrinsic>(&*RI2)) - ++RI2; - if (RI2 == RE2) - return Changed; + // All instructions in Insts are known to be the same opcode. If they aren't + // stores, check the only user of each is a PHI or in the same block as the + // instruction, because if a user is in the same block as an instruction + // we're contemplating sinking, it must already be determined to be sinkable. 
+ if (!isa<StoreInst>(I0)) { + auto *PNUse = dyn_cast<PHINode>(*I0->user_begin()); + auto *Succ = I0->getParent()->getTerminator()->getSuccessor(0); + if (!all_of(Insts, [&PNUse,&Succ](const Instruction *I) -> bool { + auto *U = cast<Instruction>(*I->user_begin()); + return (PNUse && + PNUse->getParent() == Succ && + PNUse->getIncomingValueForBlock(I->getParent()) == I) || + U->getParent() == I->getParent(); + })) + return false; + } - Instruction *I1 = &*RI1, *I2 = &*RI2; - auto InstPair = std::make_pair(I1, I2); - // I1 and I2 should have a single use in the same PHI node, and they - // perform the same operation. - // Cannot move control-flow-involving, volatile loads, vaarg, etc. - if (isa<PHINode>(I1) || isa<PHINode>(I2) || isa<TerminatorInst>(I1) || - isa<TerminatorInst>(I2) || I1->isEHPad() || I2->isEHPad() || - isa<AllocaInst>(I1) || isa<AllocaInst>(I2) || - I1->mayHaveSideEffects() || I2->mayHaveSideEffects() || - I1->mayReadOrWriteMemory() || I2->mayReadOrWriteMemory() || - !I1->hasOneUse() || !I2->hasOneUse() || !JointValueMap.count(InstPair)) - return Changed; + for (unsigned OI = 0, OE = I0->getNumOperands(); OI != OE; ++OI) { + if (I0->getOperand(OI)->getType()->isTokenTy()) + // Don't touch any operand of token type. + return false; + + // Because SROA can't handle speculating stores of selects, try not + // to sink loads or stores of allocas when we'd have to create a PHI for + // the address operand. Also, because it is likely that loads or stores + // of allocas will disappear when Mem2Reg/SROA is run, don't sink them. + // This can cause code churn which can have unintended consequences down + // the line - see https://llvm.org/bugs/show_bug.cgi?id=30244. + // FIXME: This is a workaround for a deficiency in SROA - see + // https://llvm.org/bugs/show_bug.cgi?id=30188 + if (OI == 1 && isa<StoreInst>(I0) && + any_of(Insts, [](const Instruction *I) { + return isa<AllocaInst>(I->getOperand(1)); + })) + return false; + if (OI == 0 && isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) { + return isa<AllocaInst>(I->getOperand(0)); + })) + return false; - // Check whether we should swap the operands of ICmpInst. - // TODO: Add support of communativity. - ICmpInst *ICmp1 = dyn_cast<ICmpInst>(I1), *ICmp2 = dyn_cast<ICmpInst>(I2); - bool SwapOpnds = false; - if (ICmp1 && ICmp2 && ICmp1->getOperand(0) != ICmp2->getOperand(0) && - ICmp1->getOperand(1) != ICmp2->getOperand(1) && - (ICmp1->getOperand(0) == ICmp2->getOperand(1) || - ICmp1->getOperand(1) == ICmp2->getOperand(0))) { - ICmp2->swapOperands(); - SwapOpnds = true; + auto SameAsI0 = [&I0, OI](const Instruction *I) { + assert(I->getNumOperands() == I0->getNumOperands()); + return I->getOperand(OI) == I0->getOperand(OI); + }; + if (!all_of(Insts, SameAsI0)) { + if (!canReplaceOperandWithVariable(I0, OI)) + // We can't create a PHI from this GEP. + return false; + // Don't create indirect calls! The called value is the final operand. + if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OI == OE - 1) { + // FIXME: if the call was *already* indirect, we should do this. + return false; + } + for (auto *I : Insts) + PHIOperands[I].push_back(I->getOperand(OI)); } - if (!I1->isSameOperationAs(I2)) { - if (SwapOpnds) - ICmp2->swapOperands(); - return Changed; + } + return true; +} + +// Assuming canSinkLastInstruction(Blocks) has returned true, sink the last +// instruction of every block in Blocks to their common successor, commoning +// into one instruction. 
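// Editor's sketch (illustration only, not part of the patch): at the source
// level, this kind of sinking turns
//
//   if (c)                           if (c)
//     out = a + 1;                     t = a;     // t becomes a PHI of a and b
//   else                  ---->     else
//     out = b + 1;                     t = b;
//                                    out = t + 1; // the single sunk add
//
// i.e. one copy of the common operation survives in the successor, and the
// operand that differed per predecessor block is PHI-merged.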
+static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { + auto *BBEnd = Blocks[0]->getTerminator()->getSuccessor(0); + + // canSinkLastInstruction returning true guarantees that every block has at + // least one non-terminator instruction. + SmallVector<Instruction*,4> Insts; + for (auto *BB : Blocks) { + Instruction *I = BB->getTerminator(); + do { + I = I->getPrevNode(); + } while (isa<DbgInfoIntrinsic>(I) && I != &BB->front()); + if (!isa<DbgInfoIntrinsic>(I)) + Insts.push_back(I); + } + + // The only checking we need to do now is that all users of all instructions + // are the same PHI node. canSinkLastInstruction should have checked this but + // it is slightly over-aggressive - it gets confused by commutative instructions + // so double-check it here. + Instruction *I0 = Insts.front(); + if (!isa<StoreInst>(I0)) { + auto *PNUse = dyn_cast<PHINode>(*I0->user_begin()); + if (!all_of(Insts, [&PNUse](const Instruction *I) -> bool { + auto *U = cast<Instruction>(*I->user_begin()); + return U == PNUse; + })) + return false; + } + + // We don't need to do any more checking here; canSinkLastInstruction should + // have done it all for us. + SmallVector<Value*, 4> NewOperands; + for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) { + // This check is different to that in canSinkLastInstruction. There, we + // cared about the global view once simplifycfg (and instcombine) have + // completed - it takes into account PHIs that become trivially + // simplifiable. However here we need a more local view; if an operand + // differs we create a PHI and rely on instcombine to clean up the very + // small mess we may make. + bool NeedPHI = any_of(Insts, [&I0, O](const Instruction *I) { + return I->getOperand(O) != I0->getOperand(O); + }); + if (!NeedPHI) { + NewOperands.push_back(I0->getOperand(O)); + continue; } - // The operands should be either the same or they need to be generated - // with a PHI node after sinking. We only handle the case where there is - // a single pair of different operands. - Value *DifferentOp1 = nullptr, *DifferentOp2 = nullptr; - unsigned Op1Idx = ~0U; - for (unsigned I = 0, E = I1->getNumOperands(); I != E; ++I) { - if (I1->getOperand(I) == I2->getOperand(I)) - continue; - // Early exit if we have more-than one pair of different operands or if - // we need a PHI node to replace a constant. - if (Op1Idx != ~0U || isa<Constant>(I1->getOperand(I)) || - isa<Constant>(I2->getOperand(I))) { - // If we can't sink the instructions, undo the swapping. - if (SwapOpnds) - ICmp2->swapOperands(); - return Changed; + // Create a new PHI in the successor block and populate it. + auto *Op = I0->getOperand(O); + assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!"); + auto *PN = PHINode::Create(Op->getType(), Insts.size(), + Op->getName() + ".sink", &BBEnd->front()); + for (auto *I : Insts) + PN->addIncoming(I->getOperand(O), I->getParent()); + NewOperands.push_back(PN); + } + + // Arbitrarily use I0 as the new "common" instruction; remap its operands + // and move it to the start of the successor block. + for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) + I0->getOperandUse(O).set(NewOperands[O]); + I0->moveBefore(&*BBEnd->getFirstInsertionPt()); + + // The debug location for the "common" instruction is the merged locations of + // all the commoned instructions. We start with the original location of the + // "common" instruction and iteratively merge each location in the loop below. 
+ const DILocation *Loc = I0->getDebugLoc(); + + // Update metadata and IR flags, and merge debug locations. + for (auto *I : Insts) + if (I != I0) { + Loc = DILocation::getMergedLocation(Loc, I->getDebugLoc()); + combineMetadataForCSE(I0, I); + I0->andIRFlags(I); + } + if (!isa<CallInst>(I0)) + I0->setDebugLoc(Loc); + + if (!isa<StoreInst>(I0)) { + // canSinkLastInstruction checked that all instructions were used by + // one and only one PHI node. Find that now, RAUW it to our common + // instruction and nuke it. + assert(I0->hasOneUse()); + auto *PN = cast<PHINode>(*I0->user_begin()); + PN->replaceAllUsesWith(I0); + PN->eraseFromParent(); + } + + // Finally nuke all instructions apart from the common instruction. + for (auto *I : Insts) + if (I != I0) + I->eraseFromParent(); + + return true; +} + +namespace { + + // LockstepReverseIterator - Iterates through instructions + // in a set of blocks in reverse order from the first non-terminator. + // For example (assume all blocks have size n): + // LockstepReverseIterator I([B1, B2, B3]); + // *I-- = [B1[n], B2[n], B3[n]]; + // *I-- = [B1[n-1], B2[n-1], B3[n-1]]; + // *I-- = [B1[n-2], B2[n-2], B3[n-2]]; + // ... + class LockstepReverseIterator { + ArrayRef<BasicBlock*> Blocks; + SmallVector<Instruction*,4> Insts; + bool Fail; + public: + LockstepReverseIterator(ArrayRef<BasicBlock*> Blocks) : + Blocks(Blocks) { + reset(); + } + + void reset() { + Fail = false; + Insts.clear(); + for (auto *BB : Blocks) { + Instruction *Inst = BB->getTerminator(); + for (Inst = Inst->getPrevNode(); Inst && isa<DbgInfoIntrinsic>(Inst);) + Inst = Inst->getPrevNode(); + if (!Inst) { + // Block wasn't big enough. + Fail = true; + return; + } + Insts.push_back(Inst); } - DifferentOp1 = I1->getOperand(I); - Op1Idx = I; - DifferentOp2 = I2->getOperand(I); } - DEBUG(dbgs() << "SINK common instructions " << *I1 << "\n"); - DEBUG(dbgs() << " " << *I2 << "\n"); - - // We insert the pair of different operands to JointValueMap and - // remove (I1, I2) from JointValueMap. - if (Op1Idx != ~0U) { - auto &NewPN = JointValueMap[std::make_pair(DifferentOp1, DifferentOp2)]; - if (!NewPN) { - NewPN = - PHINode::Create(DifferentOp1->getType(), 2, - DifferentOp1->getName() + ".sink", &BBEnd->front()); - NewPN->addIncoming(DifferentOp1, BB1); - NewPN->addIncoming(DifferentOp2, BB2); - DEBUG(dbgs() << "Create PHI node " << *NewPN << "\n";); + bool isValid() const { + return !Fail; + } + + void operator -- () { + if (Fail) + return; + for (auto *&Inst : Insts) { + for (Inst = Inst->getPrevNode(); Inst && isa<DbgInfoIntrinsic>(Inst);) + Inst = Inst->getPrevNode(); + // Already at beginning of block. + if (!Inst) { + Fail = true; + return; + } } - // I1 should use NewPN instead of DifferentOp1. - I1->setOperand(Op1Idx, NewPN); } - PHINode *OldPN = JointValueMap[InstPair]; - JointValueMap.erase(InstPair); - - // We need to update RE1 and RE2 if we are going to sink the first - // instruction in the basic block down. - bool UpdateRE1 = (I1 == &BB1->front()), UpdateRE2 = (I2 == &BB2->front()); - // Sink the instruction. - BBEnd->getInstList().splice(FirstNonPhiInBBEnd->getIterator(), - BB1->getInstList(), I1); - if (!OldPN->use_empty()) - OldPN->replaceAllUsesWith(I1); - OldPN->eraseFromParent(); - if (!I2->use_empty()) - I2->replaceAllUsesWith(I1); - I1->intersectOptionalDataWith(I2); - // TODO: Use combineMetadata here to preserve what metadata we can - // (analogous to the hoisting case above). 
- I2->eraseFromParent(); + ArrayRef<Instruction*> operator * () const { + return Insts; + } + }; + +} // end anonymous namespace - if (UpdateRE1) - RE1 = BB1->getInstList().rend(); - if (UpdateRE2) - RE2 = BB2->getInstList().rend(); - FirstNonPhiInBBEnd = &*I1; +/// Given an unconditional branch that goes to BBEnd, +/// check whether BBEnd has only two predecessors and the other predecessor +/// ends with an unconditional branch. If it is true, sink any common code +/// in the two predecessors to BBEnd. +static bool SinkThenElseCodeToEnd(BranchInst *BI1) { + assert(BI1->isUnconditional()); + BasicBlock *BBEnd = BI1->getSuccessor(0); + + // We support two situations: + // (1) all incoming arcs are unconditional + // (2) one incoming arc is conditional + // + // (2) is very common in switch defaults and + // else-if patterns; + // + // if (a) f(1); + // else if (b) f(2); + // + // produces: + // + // [if] + // / \ + // [f(1)] [if] + // | | \ + // | | \ + // | [f(2)]| + // \ | / + // [ end ] + // + // [end] has two unconditional predecessor arcs and one conditional. The + // conditional refers to the implicit empty 'else' arc. This conditional + // arc can also be caused by an empty default block in a switch. + // + // In this case, we attempt to sink code from all *unconditional* arcs. + // If we can sink instructions from these arcs (determined during the scan + // phase below) we insert a common successor for all unconditional arcs and + // connect that to [end], to enable sinking: + // + // [if] + // / \ + // [x(1)] [if] + // | | \ + // | | \ + // | [x(2)] | + // \ / | + // [sink.split] | + // \ / + // [ end ] + // + SmallVector<BasicBlock*,4> UnconditionalPreds; + Instruction *Cond = nullptr; + for (auto *B : predecessors(BBEnd)) { + auto *T = B->getTerminator(); + if (isa<BranchInst>(T) && cast<BranchInst>(T)->isUnconditional()) + UnconditionalPreds.push_back(B); + else if ((isa<BranchInst>(T) || isa<SwitchInst>(T)) && !Cond) + Cond = T; + else + return false; + } + if (UnconditionalPreds.size() < 2) + return false; + + bool Changed = false; + // We take a two-step approach to tail sinking. First we scan from the end of + // each block upwards in lockstep. If the n'th instruction from the end of each + // block can be sunk, those instructions are added to ValuesToSink and we + // carry on. If we can sink an instruction but need to PHI-merge some operands + // (because they're not identical in each instruction) we add these to + // PHIOperands. + unsigned ScanIdx = 0; + SmallPtrSet<Value*,4> InstructionsToSink; + DenseMap<Instruction*, SmallVector<Value*,4>> PHIOperands; + LockstepReverseIterator LRI(UnconditionalPreds); + while (LRI.isValid() && + canSinkInstructions(*LRI, PHIOperands)) { + DEBUG(dbgs() << "SINK: instruction can be sunk: " << *(*LRI)[0] << "\n"); + InstructionsToSink.insert((*LRI).begin(), (*LRI).end()); + ++ScanIdx; + --LRI; + } + + auto ProfitableToSinkInstruction = [&](LockstepReverseIterator &LRI) { + unsigned NumPHIdValues = 0; + for (auto *I : *LRI) + for (auto *V : PHIOperands[I]) + if (InstructionsToSink.count(V) == 0) + ++NumPHIdValues; + DEBUG(dbgs() << "SINK: #phid values: " << NumPHIdValues << "\n"); + unsigned NumPHIInsts = NumPHIdValues / UnconditionalPreds.size(); + if ((NumPHIdValues % UnconditionalPreds.size()) != 0) + NumPHIInsts++; + + return NumPHIInsts <= 1; + }; + + if (ScanIdx > 0 && Cond) { + // Check if we would actually sink anything first! This mutates the CFG and + // adds an extra block. 
The goal in doing this is to allow instructions that + // couldn't be sunk before to be sunk - obviously, speculatable instructions + // (such as trunc, add) can be sunk and predicated already. So we check that + // we're going to sink at least one non-speculatable instruction. + LRI.reset(); + unsigned Idx = 0; + bool Profitable = false; + while (ProfitableToSinkInstruction(LRI) && Idx < ScanIdx) { + if (!isSafeToSpeculativelyExecute((*LRI)[0])) { + Profitable = true; + break; + } + --LRI; + ++Idx; + } + if (!Profitable) + return false; + + DEBUG(dbgs() << "SINK: Splitting edge\n"); + // We have a conditional edge and we're going to sink some instructions. + // Insert a new block postdominating all blocks we're going to sink from. + if (!SplitBlockPredecessors(BI1->getSuccessor(0), UnconditionalPreds, + ".sink.split")) + // Edges couldn't be split. + return false; + Changed = true; + } + + // Now that we've analyzed all potential sinking candidates, perform the + // actual sink. We iteratively sink the last non-terminator of the source + // blocks into their common successor unless doing so would require too + // many PHI instructions to be generated (currently only one PHI is allowed + // per sunk instruction). + // + // We can use InstructionsToSink to discount values needing PHI-merging that will + // actually be sunk in a later iteration. This allows us to be more + // aggressive in what we sink. This does allow a false positive where we + // sink presuming a later value will also be sunk, but stop half way through + // and never actually sink it which means we produce more PHIs than intended. + // This is unlikely in practice though. + for (unsigned SinkIdx = 0; SinkIdx != ScanIdx; ++SinkIdx) { + DEBUG(dbgs() << "SINK: Sink: " + << *UnconditionalPreds[0]->getTerminator()->getPrevNode() + << "\n"); + + // Because we've sunk every instruction in turn, the current instruction to + // sink is always at index 0. + LRI.reset(); + if (!ProfitableToSinkInstruction(LRI)) { + // Too many PHIs would be created. + DEBUG(dbgs() << "SINK: stopping here, too many PHIs would be created!\n"); + break; + } + + if (!sinkLastInstruction(UnconditionalPreds)) + return Changed; NumSinkCommons++; Changed = true; } @@ -1539,7 +1883,7 @@ static Value *isSafeToSpeculateStore(Instruction *I, BasicBlock *BrBB, continue; --MaxNumInstToLookAt; - // Could be calling an instruction that effects memory like free(). + // Could be calling an instruction that affects memory like free(). if (CurI.mayHaveSideEffects() && !isa<StoreInst>(CurI)) return nullptr; @@ -1822,7 +2166,7 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL) { return false; // Can't fold blocks that contain noduplicate or convergent calls. - if (llvm::any_of(*BB, [](const Instruction &I) { + if (any_of(*BB, [](const Instruction &I) { const CallInst *CI = dyn_cast<CallInst>(&I); return CI && (CI->cannotDuplicate() || CI->isConvergent()); })) @@ -2464,6 +2808,11 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { PBI = New_PBI; } + // If BI was a loop latch, it may have had associated loop metadata. + // We need to copy it to the new latch, that is, PBI. + if (MDNode *LoopMD = BI->getMetadata(LLVMContext::MD_loop)) + PBI->setMetadata(LLVMContext::MD_loop, LoopMD); + // TODO: If BB is reachable from all paths through PredBlock, then we // could replace PBI's branch probabilities with BI's. 
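Returning to the profitability check in the sinking code above: the budget is one PHI per sunk instruction, computed by rounding up NumPHIdValues over the number of unconditional predecessors. A worked example with invented counts (a sketch, not taken from the patch):

unsigned NumPHIdValues = 4, NumPreds = 3;     // hypothetical counts
unsigned NumPHIInsts = NumPHIdValues / NumPreds;
if (NumPHIdValues % NumPreds != 0)
  ++NumPHIInsts;                              // ceil(4 / 3) == 2
bool ProfitableHere = (NumPHIInsts <= 1);     // false: stop sinking at this depth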
@@ -4150,18 +4499,28 @@ static bool ForwardSwitchConditionToPHI(SwitchInst *SI) { /// Return true if the backend will be able to handle /// initializing an array of constants like C. -static bool ValidLookupTableConstant(Constant *C) { +static bool ValidLookupTableConstant(Constant *C, const TargetTransformInfo &TTI) { if (C->isThreadDependent()) return false; if (C->isDLLImportDependent()) return false; - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) - return CE->isGEPWithNoNotionalOverIndexing(); + if (!isa<ConstantFP>(C) && !isa<ConstantInt>(C) && + !isa<ConstantPointerNull>(C) && !isa<GlobalValue>(C) && + !isa<UndefValue>(C) && !isa<ConstantExpr>(C)) + return false; + + if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { + if (!CE->isGEPWithNoNotionalOverIndexing()) + return false; + if (!ValidLookupTableConstant(CE->getOperand(0), TTI)) + return false; + } + + if (!TTI.shouldBuildLookupTablesForConstant(C)) + return false; - return isa<ConstantFP>(C) || isa<ConstantInt>(C) || - isa<ConstantPointerNull>(C) || isa<GlobalValue>(C) || - isa<UndefValue>(C); + return true; } /// If V is a Constant, return it. Otherwise, try to look up @@ -4216,7 +4575,7 @@ static bool GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, BasicBlock **CommonDest, SmallVectorImpl<std::pair<PHINode *, Constant *>> &Res, - const DataLayout &DL) { + const DataLayout &DL, const TargetTransformInfo &TTI) { // The block from which we enter the common destination. BasicBlock *Pred = SI->getParent(); @@ -4228,7 +4587,7 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, ++I) { if (TerminatorInst *T = dyn_cast<TerminatorInst>(I)) { // If the terminator is a simple branch, continue to the next block. - if (T->getNumSuccessors() != 1) + if (T->getNumSuccessors() != 1 || T->isExceptional()) return false; Pred = CaseDest; CaseDest = T->getSuccessor(0); @@ -4278,7 +4637,7 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, return false; // Be conservative about which kinds of constants we support. - if (!ValidLookupTableConstant(ConstVal)) + if (!ValidLookupTableConstant(ConstVal, TTI)) return false; Res.push_back(std::make_pair(PHI, ConstVal)); @@ -4310,14 +4669,15 @@ static bool InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, BasicBlock *&CommonDest, SwitchCaseResultVectorTy &UniqueResults, Constant *&DefaultResult, - const DataLayout &DL) { + const DataLayout &DL, + const TargetTransformInfo &TTI) { for (auto &I : SI->cases()) { ConstantInt *CaseVal = I.getCaseValue(); // Resulting value at phi nodes for this case value. SwitchCaseResultsTy Results; if (!GetCaseResults(SI, CaseVal, I.getCaseSuccessor(), &CommonDest, Results, - DL)) + DL, TTI)) return false; // Only one value per case is permitted @@ -4335,7 +4695,7 @@ static bool InitializeUniqueCases(SwitchInst *SI, PHINode *&PHI, SmallVector<std::pair<PHINode *, Constant *>, 1> DefaultResults; BasicBlock *DefaultDest = SI->getDefaultDest(); GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, DefaultResults, - DL); + DL, TTI); // If the default value is not found abort unless the default destination // is unreachable. DefaultResult = @@ -4414,7 +4774,8 @@ static void RemoveSwitchAfterSelectConversion(SwitchInst *SI, PHINode *PHI, /// phi nodes in a common successor block with only two different /// constant values, replace the switch with select. 
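// Editor's sketch (illustration only, not part of the patch), assuming the
// shape checks in the code below succeed, for instance two cases with distinct
// results and a reachable default:
//
//   switch (x) {                    // becomes, roughly:
//   case 3:  r = 10; break;         //   r = (x == 3) ? 10
//   case 7:  r = 20; break;         //       : ((x == 7) ? 20 : 30);
//   default: r = 30; break;
//   }
//
// The PHI in the common successor is then fed by nested selects instead of by
// separate switch arms.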
static bool SwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, - AssumptionCache *AC, const DataLayout &DL) { + AssumptionCache *AC, const DataLayout &DL, + const TargetTransformInfo &TTI) { Value *const Cond = SI->getCondition(); PHINode *PHI = nullptr; BasicBlock *CommonDest = nullptr; @@ -4422,7 +4783,7 @@ static bool SwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, SwitchCaseResultVectorTy UniqueResults; // Collect all the cases that will deliver the same value from the switch. if (!InitializeUniqueCases(SI, PHI, CommonDest, UniqueResults, DefaultResult, - DL)) + DL, TTI)) return false; // Selects choose between maximum two values. if (UniqueResults.size() != 2) @@ -4441,6 +4802,7 @@ static bool SwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, } namespace { + /// This class represents a lookup table that can be used to replace a switch. class SwitchLookupTable { public: @@ -4497,7 +4859,8 @@ private: // For ArrayKind, this is the array. GlobalVariable *Array; }; -} + +} // end anonymous namespace SwitchLookupTable::SwitchLookupTable( Module &M, uint64_t TableSize, ConstantInt *Offset, @@ -4860,7 +5223,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, typedef SmallVector<std::pair<PHINode *, Constant *>, 4> ResultsTy; ResultsTy Results; if (!GetCaseResults(SI, CaseVal, CI.getCaseSuccessor(), &CommonDest, - Results, DL)) + Results, DL, TTI)) return false; // Append the result from this case to the list for each phi. @@ -4886,8 +5249,9 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // If the table has holes, we need a constant result for the default case // or a bitmask that fits in a register. SmallVector<std::pair<PHINode *, Constant *>, 4> DefaultResultsList; - bool HasDefaultResults = GetCaseResults(SI, nullptr, SI->getDefaultDest(), - &CommonDest, DefaultResultsList, DL); + bool HasDefaultResults = + GetCaseResults(SI, nullptr, SI->getDefaultDest(), &CommonDest, + DefaultResultsList, DL, TTI); bool NeedMask = (TableHasHoles && !HasDefaultResults); if (NeedMask) { @@ -5044,6 +5408,111 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, return true; } +static bool isSwitchDense(ArrayRef<int64_t> Values) { + // See also SelectionDAGBuilder::isDense(), which this function was based on. + uint64_t Diff = (uint64_t)Values.back() - (uint64_t)Values.front(); + uint64_t Range = Diff + 1; + uint64_t NumCases = Values.size(); + // 40% is the default density for building a jump table in optsize/minsize mode. + uint64_t MinDensity = 40; + + return NumCases * 100 >= Range * MinDensity; +} + +// Try and transform a switch that has "holes" in it to a contiguous sequence +// of cases. +// +// A switch such as: switch(i) {case 5: case 9: case 13: case 17:} can be +// range-reduced to: switch ((i-5) / 4) {case 0: case 1: case 2: case 3:}. +// +// This converts a sparse switch into a dense switch which allows better +// lowering and could also allow transforming into a lookup table. +static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, + const DataLayout &DL, + const TargetTransformInfo &TTI) { + auto *CondTy = cast<IntegerType>(SI->getCondition()->getType()); + if (CondTy->getIntegerBitWidth() > 64 || + !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth())) + return false; + // Only bother with this optimization if there are more than 3 switch cases; + // SDAG will only bother creating jump tables for 4 or more cases. 
+ if (SI->getNumCases() < 4) + return false; + + // This transform is agnostic to the signedness of the input or case values. We + // can treat the case values as signed or unsigned. We can optimize more common + // cases such as a sequence crossing zero {-4,0,4,8} if we interpret case values + // as signed. + SmallVector<int64_t,4> Values; + for (auto &C : SI->cases()) + Values.push_back(C.getCaseValue()->getValue().getSExtValue()); + std::sort(Values.begin(), Values.end()); + + // If the switch is already dense, there's nothing useful to do here. + if (isSwitchDense(Values)) + return false; + + // First, transform the values such that they start at zero and ascend. + int64_t Base = Values[0]; + for (auto &V : Values) + V -= Base; + + // Now we have signed numbers that have been shifted so that, given enough + // precision, there are no negative values. Since the rest of the transform + // is bitwise only, we switch now to an unsigned representation. + uint64_t GCD = 0; + for (auto &V : Values) + GCD = GreatestCommonDivisor64(GCD, (uint64_t)V); + + // This transform can be done speculatively because it is so cheap - it results + // in a single rotate operation being inserted. This can only happen if the + // factor extracted is a power of 2. + // FIXME: If the GCD is an odd number we can multiply by the multiplicative + // inverse of GCD and then perform this transform. + // FIXME: It's possible that optimizing a switch on powers of two might also + // be beneficial - flag values are often powers of two and we could use a CLZ + // as the key function. + if (GCD <= 1 || !isPowerOf2_64(GCD)) + // No common divisor found or too expensive to compute key function. + return false; + + unsigned Shift = Log2_64(GCD); + for (auto &V : Values) + V = (int64_t)((uint64_t)V >> Shift); + + if (!isSwitchDense(Values)) + // Transform didn't create a dense switch. + return false; + + // The obvious transform is to shift the switch condition right and emit a + // check that the condition actually cleanly divided by GCD, i.e. + // C & (1 << Shift - 1) == 0 + // inserting a new CFG edge to handle the case where it didn't divide cleanly. + // + // A cheaper way of doing this is a simple ROTR(C, Shift). This performs the + // shift and puts the shifted-off bits in the uppermost bits. If any of these + // are nonzero then the switch condition will be very large and will hit the + // default case. 
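// Editor's worked example (not part of the patch), using the sparse switch
// over i with cases {5, 9, 13, 17} and a 32-bit condition mentioned above:
//   Base = 5            -> rebased cases {0, 4, 8, 12}
//   GCD  = 4, Shift = 2 -> final cases   {0, 1, 2, 3}
//   key  = ROTR(i - 5, 2) = ((i - 5) >> 2) | ((i - 5) << 30)
// so i = 5, 9, 13, 17 map to keys 0, 1, 2, 3, while any i for which (i - 5)
// has either of its low two bits set rotates those bits into the high end,
// produces a huge key, and falls through to the default case.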
+ + auto *Ty = cast<IntegerType>(SI->getCondition()->getType()); + Builder.SetInsertPoint(SI); + auto *ShiftC = ConstantInt::get(Ty, Shift); + auto *Sub = Builder.CreateSub(SI->getCondition(), ConstantInt::get(Ty, Base)); + auto *LShr = Builder.CreateLShr(Sub, ShiftC); + auto *Shl = Builder.CreateShl(Sub, Ty->getBitWidth() - Shift); + auto *Rot = Builder.CreateOr(LShr, Shl); + SI->replaceUsesOfWith(SI->getCondition(), Rot); + + for (SwitchInst::CaseIt C = SI->case_begin(), E = SI->case_end(); C != E; + ++C) { + auto *Orig = C.getCaseValue(); + auto Sub = Orig->getValue() - APInt(Ty->getBitWidth(), Base); + C.setValue( + cast<ConstantInt>(ConstantInt::get(Ty, Sub.lshr(ShiftC->getValue())))); + } + return true; +} + bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { BasicBlock *BB = SI->getParent(); @@ -5078,7 +5547,7 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { if (EliminateDeadSwitchCases(SI, AC, DL)) return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; - if (SwitchToSelect(SI, Builder, AC, DL)) + if (SwitchToSelect(SI, Builder, AC, DL, TTI)) return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; if (ForwardSwitchConditionToPHI(SI)) @@ -5087,6 +5556,9 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { if (SwitchToLookupTable(SI, Builder, DL, TTI)) return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + if (ReduceSwitchRange(SI, Builder, DL, TTI)) + return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return false; } @@ -5397,7 +5869,10 @@ static bool passingValueIsAlwaysUndefined(Value *V, Instruction *I) { // Now make sure that there are no instructions in between that can alter // control flow (eg. calls) - for (BasicBlock::iterator i = ++BasicBlock::iterator(I); &*i != Use; ++i) + for (BasicBlock::iterator + i = ++BasicBlock::iterator(I), + UI = BasicBlock::iterator(dyn_cast<Instruction>(Use)); + i != UI; ++i) if (i == I->getParent()->end() || i->mayHaveSideEffects()) return false; diff --git a/contrib/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp b/contrib/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp index df29906..1220490 100644 --- a/contrib/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SimplifyInstructions.cpp @@ -34,14 +34,14 @@ using namespace llvm; STATISTIC(NumSimplified, "Number of redundant instructions removed"); -static bool runImpl(Function &F, const DominatorTree *DT, const TargetLibraryInfo *TLI, - AssumptionCache *AC) { +static bool runImpl(Function &F, const DominatorTree *DT, + const TargetLibraryInfo *TLI, AssumptionCache *AC) { const DataLayout &DL = F.getParent()->getDataLayout(); - SmallPtrSet<const Instruction*, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; + SmallPtrSet<const Instruction *, 8> S1, S2, *ToSimplify = &S1, *Next = &S2; bool Changed = false; do { - for (BasicBlock *BB : depth_first(&F.getEntryBlock())) + for (BasicBlock *BB : depth_first(&F.getEntryBlock())) { // Here be subtlety: the iterator must be incremented before the loop // body (not sure why), so a range-for loop won't work here. for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { @@ -51,8 +51,9 @@ static bool runImpl(Function &F, const DominatorTree *DT, const TargetLibraryInf // empty and we only bother simplifying instructions that are in it. if (!ToSimplify->empty() && !ToSimplify->count(I)) continue; + // Don't waste time simplifying unused instructions. 
- if (!I->use_empty()) + if (!I->use_empty()) { if (Value *V = SimplifyInstruction(I, DL, TLI, DT, AC)) { // Mark all uses for resimplification next time round the loop. for (User *U : I->users()) @@ -61,16 +62,17 @@ static bool runImpl(Function &F, const DominatorTree *DT, const TargetLibraryInf ++NumSimplified; Changed = true; } - bool res = RecursivelyDeleteTriviallyDeadInstructions(I, TLI); - if (res) { - // RecursivelyDeleteTriviallyDeadInstruction can remove - // more than one instruction, so simply incrementing the - // iterator does not work. When instructions get deleted - // re-iterate instead. - BI = BB->begin(); BE = BB->end(); - Changed |= res; + } + if (RecursivelyDeleteTriviallyDeadInstructions(I, TLI)) { + // RecursivelyDeleteTriviallyDeadInstruction can remove more than one + // instruction, so simply incrementing the iterator does not work. + // When instructions get deleted re-iterate instead. + BI = BB->begin(); + BE = BB->end(); + Changed = true; } } + } // Place the list of instructions to simplify on the next loop iteration // into ToSimplify. @@ -90,6 +92,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } @@ -99,9 +102,8 @@ namespace { if (skipFunction(F)) return false; - const DominatorTreeWrapperPass *DTWP = - getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - const DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr; + const DominatorTree *DT = + &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); AssumptionCache *AC = @@ -115,6 +117,7 @@ char InstSimplifier::ID = 0; INITIALIZE_PASS_BEGIN(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(InstSimplifier, "instsimplify", "Remove redundant instructions", false, false) @@ -126,11 +129,11 @@ FunctionPass *llvm::createInstructionSimplifierPass() { } PreservedAnalyses InstSimplifierPass::run(Function &F, - AnalysisManager<Function> &AM) { - auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); - bool Changed = runImpl(F, DT, &TLI, &AC); + bool Changed = runImpl(F, &DT, &TLI, &AC); if (!Changed) return PreservedAnalyses::all(); // FIXME: This should also 'preserve the CFG'. diff --git a/contrib/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/contrib/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index c298695..8eaeb10 100644 --- a/contrib/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -56,6 +56,38 @@ static bool ignoreCallingConv(LibFunc::Func Func) { Func == LibFunc::llabs || Func == LibFunc::strlen; } +static bool isCallingConvCCompatible(CallInst *CI) { + switch(CI->getCallingConv()) { + default: + return false; + case llvm::CallingConv::C: + return true; + case llvm::CallingConv::ARM_APCS: + case llvm::CallingConv::ARM_AAPCS: + case llvm::CallingConv::ARM_AAPCS_VFP: { + + // The iOS ABI diverges from the standard in some cases, so for now don't + // try to simplify those calls. 
+ if (Triple(CI->getModule()->getTargetTriple()).isiOS()) + return false; + + auto *FuncTy = CI->getFunctionType(); + + if (!FuncTy->getReturnType()->isPointerTy() && + !FuncTy->getReturnType()->isIntegerTy() && + !FuncTy->getReturnType()->isVoidTy()) + return false; + + for (auto Param : FuncTy->params()) { + if (!Param->isPointerTy() && !Param->isIntegerTy()) + return false; + } + return true; + } + } + return false; +} + /// Return true if it only matters that the value is equal or not-equal to zero. static bool isOnlyUsedInZeroEqualityComparison(Value *V) { for (User *U : V->users()) { @@ -83,7 +115,7 @@ static bool isOnlyUsedInEqualityComparison(Value *V, Value *With) { } static bool callHasFloatingPointArgument(const CallInst *CI) { - return std::any_of(CI->op_begin(), CI->op_end(), [](const Use &OI) { + return any_of(CI->operands(), [](const Use &OI) { return OI->getType()->isFloatingPointTy(); }); } @@ -868,7 +900,7 @@ static Value *valueHasFloatPrecision(Value *Val) { if (ConstantFP *Const = dyn_cast<ConstantFP>(Val)) { APFloat F = Const->getValueAPF(); bool losesInfo; - (void)F.convert(APFloat::IEEEsingle, APFloat::rmNearestTiesToEven, + (void)F.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &losesInfo); if (!losesInfo) return ConstantFP::get(Const->getContext(), F); @@ -993,16 +1025,20 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { Ret = optimizeUnaryDoubleFP(CI, B, true); Value *Op1 = CI->getArgOperand(0), *Op2 = CI->getArgOperand(1); + + // pow(1.0, x) -> 1.0 + if (match(Op1, m_SpecificFP(1.0))) + return Op1; + // pow(2.0, x) -> llvm.exp2(x) + if (match(Op1, m_SpecificFP(2.0))) { + Value *Exp2 = Intrinsic::getDeclaration(CI->getModule(), Intrinsic::exp2, + CI->getType()); + return B.CreateCall(Exp2, Op2, "exp2"); + } + + // There's no llvm.exp10 intrinsic yet, but, maybe, some day there will + // be one. if (ConstantFP *Op1C = dyn_cast<ConstantFP>(Op1)) { - // pow(1.0, x) -> 1.0 - if (Op1C->isExactlyValue(1.0)) - return Op1C; - // pow(2.0, x) -> exp2(x) - if (Op1C->isExactlyValue(2.0) && - hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp2, LibFunc::exp2f, - LibFunc::exp2l)) - return emitUnaryFloatFnCall(Op2, TLI->getName(LibFunc::exp2), B, - Callee->getAttributes()); // pow(10.0, x) -> exp10(x) if (Op1C->isExactlyValue(10.0) && hasUnaryFloatFn(TLI, Op1->getType(), LibFunc::exp10, LibFunc::exp10f, @@ -1038,6 +1074,24 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 return ConstantFP::get(CI->getType(), 1.0); + if (Op2C->isExactlyValue(-0.5) && + hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, + LibFunc::sqrtl)) { + // If -ffast-math: + // pow(x, -0.5) -> 1.0 / sqrt(x) + if (CI->hasUnsafeAlgebra()) { + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + + // Here we cannot lower to an intrinsic because C99 sqrt() and llvm.sqrt + // are not guaranteed to have the same semantics. 
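// Editor's note (illustration only, not part of the patch): in source terms
// the guarded rewrite below turns
//     y = pow(x, -0.5);
// into
//     y = 1.0 / sqrt(x);
// It is gated on unsafe-algebra because the two forms are not exact drop-in
// replacements in general (for example they can disagree on signed zero and
// on errno behaviour), and the libcall sqrt() is emitted rather than the
// llvm.sqrt intrinsic for the reason given in the comment above.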
+ Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B, + Callee->getAttributes()); + + return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Sqrt, "sqrtrecip"); + } + } + if (Op2C->isExactlyValue(0.5) && hasUnaryFloatFn(TLI, Op2->getType(), LibFunc::sqrt, LibFunc::sqrtf, LibFunc::sqrtl) && @@ -1048,6 +1102,9 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { if (CI->hasUnsafeAlgebra()) { IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); + + // Unlike other math intrinsics, sqrt has differerent semantics + // from the libc function. See LangRef for details. return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc::sqrt), B, Callee->getAttributes()); } @@ -1082,6 +1139,10 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { !V.isInteger()) return nullptr; + // Propagate fast math flags. + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + // We will memoize intermediate products of the Addition Chain. Value *InnerChain[33] = {nullptr}; InnerChain[1] = Op1; @@ -1090,9 +1151,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { // We cannot readily convert a non-double type (like float) to a double. // So we first convert V to something which could be converted to double. bool ignored; - V.convert(APFloat::IEEEdouble, APFloat::rmTowardZero, &ignored); + V.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &ignored); - // TODO: Should the new instructions propagate the 'fast' flag of the pow()? Value *FMul = getPow(InnerChain, V.convertToDouble(), B); // For negative exponents simply compute the reciprocal. if (Op2C->isNegative()) @@ -1150,19 +1210,11 @@ Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeFabs(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); - Value *Ret = nullptr; StringRef Name = Callee->getName(); if (Name == "fabs" && hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, false); + return optimizeUnaryDoubleFP(CI, B, false); - Value *Op = CI->getArgOperand(0); - if (Instruction *I = dyn_cast<Instruction>(Op)) { - // Fold fabs(x * x) -> x * x; any squared FP value must already be positive. - if (I->getOpcode() == Instruction::FMul) - if (I->getOperand(0) == I->getOperand(1)) - return Op; - } - return Ret; + return nullptr; } Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) { @@ -1428,6 +1480,12 @@ Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, IRBuilder<> &B) { Value *Sin, *Cos, *SinCos; insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos, SinCos); + auto replaceTrigInsts = [this](SmallVectorImpl<CallInst *> &Calls, + Value *Res) { + for (CallInst *C : Calls) + replaceAllUsesWith(C, Res); + }; + replaceTrigInsts(SinCalls, Sin); replaceTrigInsts(CosCalls, Cos); replaceTrigInsts(SinCosCalls, SinCos); @@ -1472,32 +1530,16 @@ void LibCallSimplifier::classifyArgUse( } } -void LibCallSimplifier::replaceTrigInsts(SmallVectorImpl<CallInst *> &Calls, - Value *Res) { - for (CallInst *C : Calls) - replaceAllUsesWith(C, Res); -} - //===----------------------------------------------------------------------===// // Integer Library Call Optimizations //===----------------------------------------------------------------------===// Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - Value *Op = CI->getArgOperand(0); - - // Constant fold. 
- if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { - if (CI->isZero()) // ffs(0) -> 0. - return B.getInt32(0); - // ffs(c) -> cttz(c)+1 - return B.getInt32(CI->getValue().countTrailingZeros() + 1); - } - // ffs(x) -> x != 0 ? (i32)llvm.cttz(x)+1 : 0 + Value *Op = CI->getArgOperand(0); Type *ArgType = Op->getType(); - Value *F = - Intrinsic::getDeclaration(Callee->getParent(), Intrinsic::cttz, ArgType); + Value *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), + Intrinsic::cttz, ArgType); Value *V = B.CreateCall(F, {Op, B.getTrue()}, "cttz"); V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1)); V = B.CreateIntCast(V, B.getInt32Ty(), false); @@ -1506,6 +1548,18 @@ Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilder<> &B) { return B.CreateSelect(Cond, V, B.getInt32(0)); } +Value *LibCallSimplifier::optimizeFls(CallInst *CI, IRBuilder<> &B) { + // fls(x) -> (i32)(sizeInBits(x) - llvm.ctlz(x, false)) + Value *Op = CI->getArgOperand(0); + Type *ArgType = Op->getType(); + Value *F = Intrinsic::getDeclaration(CI->getCalledFunction()->getParent(), + Intrinsic::ctlz, ArgType); + Value *V = B.CreateCall(F, {Op, B.getFalse()}, "ctlz"); + V = B.CreateSub(ConstantInt::get(V->getType(), ArgType->getIntegerBitWidth()), + V); + return B.CreateIntCast(V, CI->getType(), false); +} + Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilder<> &B) { // abs(x) -> x >s -1 ? x : -x Value *Op = CI->getArgOperand(0); @@ -1891,7 +1945,7 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, if (TLI->getLibFunc(*Callee, Func) && TLI->has(Func)) { // Make sure we never change the calling convention. assert((ignoreCallingConv(Func) || - CI->getCallingConv() == llvm::CallingConv::C) && + isCallingConvCCompatible(CI)) && "Optimizing string/memory libcall would change the calling convention"); switch (Func) { case LibFunc::strcat: @@ -1958,7 +2012,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { SmallVector<OperandBundleDef, 2> OpBundles; CI->getOperandBundlesAsDefs(OpBundles); IRBuilder<> Builder(CI, /*FPMathTag=*/nullptr, OpBundles); - bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C; + bool isCallingConvC = isCallingConvCCompatible(CI); // Command-line parameter overrides instruction attribute. if (EnableUnsafeFPShrink.getNumOccurrences() > 0) @@ -2042,6 +2096,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { case LibFunc::ffsl: case LibFunc::ffsll: return optimizeFFS(CI, Builder); + case LibFunc::fls: + case LibFunc::flsl: + case LibFunc::flsll: + return optimizeFls(CI, Builder); case LibFunc::abs: case LibFunc::labs: case LibFunc::llabs: @@ -2314,7 +2372,7 @@ Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI) { SmallVector<OperandBundleDef, 2> OpBundles; CI->getOperandBundlesAsDefs(OpBundles); IRBuilder<> Builder(CI, /*FPMathTag=*/nullptr, OpBundles); - bool isCallingConvC = CI->getCallingConv() == llvm::CallingConv::C; + bool isCallingConvC = isCallingConvCCompatible(CI); // First, check that this is a known library functions and that the prototype // is correct. 
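As a reader's aid for the new fls()/flsl()/flsll() handling above, here is a small self-contained C++ model of the identity it relies on, fls(x) == bitwidth - ctlz(x) with ctlz(0) defined as the full bit width. The function name is invented and this is only a sketch of the identity, not code from the patch:

#include <cstdint>

// Portable model of the fls lowering for a 32-bit argument:
// 32 - ctlz(x), where ctlz(0) counts all 32 bits.
int flsModel(uint32_t x) {
  int LeadingZeros = 0;
  for (uint32_t Mask = UINT32_C(1) << 31; Mask != 0 && !(x & Mask); Mask >>= 1)
    ++LeadingZeros;            // reaches 32 when x == 0
  return 32 - LeadingZeros;    // fls(0) == 0, fls(1) == 1, fls(1u << 31) == 32
}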
diff --git a/contrib/llvm/lib/Transforms/Utils/StripGCRelocates.cpp b/contrib/llvm/lib/Transforms/Utils/StripGCRelocates.cpp
new file mode 100644
index 0000000..f3d3fad
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Utils/StripGCRelocates.cpp
@@ -0,0 +1,80 @@
+//===- StripGCRelocates.cpp - Remove gc.relocates inserted by RewriteStatePoints===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This is a little utility pass that removes the gc.relocates inserted by
+// RewriteStatepointsForGC. Note that the generated IR is incorrect,
+// but this is useful as a single pass in itself, for analysis of IR, without
+// the GC.relocates. The statepoint and gc.result instrinsics would still be
+// present.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Statepoint.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+namespace {
+struct StripGCRelocates : public FunctionPass {
+  static char ID; // Pass identification, replacement for typeid
+  StripGCRelocates() : FunctionPass(ID) {
+    initializeStripGCRelocatesPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &Info) const override {}
+
+  bool runOnFunction(Function &F) override;
+
+};
+char StripGCRelocates::ID = 0;
+}
+
+bool StripGCRelocates::runOnFunction(Function &F) {
+  // Nothing to do for declarations.
+  if (F.isDeclaration())
+    return false;
+  SmallVector<GCRelocateInst *, 20> GCRelocates;
+  // TODO: We currently do not handle gc.relocates that are in landing pads,
+  // i.e. not bound to a single statepoint token.
+  for (Instruction &I : instructions(F)) {
+    if (auto *GCR = dyn_cast<GCRelocateInst>(&I))
+      if (isStatepoint(GCR->getOperand(0)))
+        GCRelocates.push_back(GCR);
+  }
+  // All gc.relocates are bound to a single statepoint token. The order of
+  // visiting gc.relocates for deletion does not matter.
+  for (GCRelocateInst *GCRel : GCRelocates) {
+    Value *OrigPtr = GCRel->getDerivedPtr();
+    Value *ReplaceGCRel = OrigPtr;
+
+    // All gc_relocates are i8 addrspace(1)* typed, we need a bitcast from i8
+    // addrspace(1)* to the type of the OrigPtr, if the are not the same.
+    if (GCRel->getType() != OrigPtr->getType())
+      ReplaceGCRel = new BitCastInst(OrigPtr, GCRel->getType(), "cast", GCRel);
+
+    // Replace all uses of gc.relocate and delete the gc.relocate
+    // There maybe unncessary bitcasts back to the OrigPtr type, an instcombine
+    // pass would clear this up.
+ GCRel->replaceAllUsesWith(ReplaceGCRel); + GCRel->eraseFromParent(); + } + return !GCRelocates.empty(); +} + +INITIALIZE_PASS(StripGCRelocates, "strip-gc-relocates", + "Strip gc.relocates inserted through RewriteStatepointsForGC", + true, false) +FunctionPass *llvm::createStripGCRelocatesPass() { + return new StripGCRelocates(); +} diff --git a/contrib/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp b/contrib/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp new file mode 100644 index 0000000..66dbf33 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Utils/StripNonLineTableDebugInfo.cpp @@ -0,0 +1,42 @@ +//===- StripNonLineTableDebugInfo.cpp -- Strip parts of Debug Info --------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/Pass.h" +using namespace llvm; + +namespace { + +/// This pass strips all debug info that is not related to line tables. +/// The result will be the same as if the program were compiled with +/// -gline-tables-only. +struct StripNonLineTableDebugInfo : public ModulePass { + static char ID; // Pass identification, replacement for typeid + StripNonLineTableDebugInfo() : ModulePass(ID) { + initializeStripNonLineTableDebugInfoPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + + bool runOnModule(Module &M) override { + return llvm::stripNonLineTableDebugInfo(M); + } +}; +} + +char StripNonLineTableDebugInfo::ID = 0; +INITIALIZE_PASS(StripNonLineTableDebugInfo, "strip-nonlinetable-debuginfo", + "Strip all debug info except linetables", false, false) + +ModulePass *llvm::createStripNonLineTableDebugInfoPass() { + return new StripNonLineTableDebugInfo(); +} diff --git a/contrib/llvm/lib/Transforms/Utils/SymbolRewriter.cpp b/contrib/llvm/lib/Transforms/Utils/SymbolRewriter.cpp index 7523ca5..6d13663 100644 --- a/contrib/llvm/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/contrib/llvm/lib/Transforms/Utils/SymbolRewriter.cpp @@ -58,6 +58,7 @@ //===----------------------------------------------------------------------===// #define DEBUG_TYPE "symbol-rewriter" +#include "llvm/Transforms/Utils/SymbolRewriter.h" #include "llvm/Pass.h" #include "llvm/ADT/SmallString.h" #include "llvm/IR/LegacyPassManager.h" @@ -68,7 +69,6 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/YAMLParser.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Utils/SymbolRewriter.h" using namespace llvm; using namespace SymbolRewriter; @@ -361,9 +361,11 @@ parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, // TODO see if there is a more elegant solution to selecting the rewrite // descriptor type if (!Target.empty()) - DL->push_back(new ExplicitRewriteFunctionDescriptor(Source, Target, Naked)); + DL->push_back(llvm::make_unique<ExplicitRewriteFunctionDescriptor>( + Source, Target, Naked)); else - DL->push_back(new PatternRewriteFunctionDescriptor(Source, Transform)); + DL->push_back( + llvm::make_unique<PatternRewriteFunctionDescriptor>(Source, Transform)); return true; } @@ -421,11 +423,12 @@ parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, } if (!Target.empty()) - DL->push_back(new ExplicitRewriteGlobalVariableDescriptor(Source, Target, - /*Naked*/false)); 
+ DL->push_back(llvm::make_unique<ExplicitRewriteGlobalVariableDescriptor>( + Source, Target, + /*Naked*/ false)); else - DL->push_back(new PatternRewriteGlobalVariableDescriptor(Source, - Transform)); + DL->push_back(llvm::make_unique<PatternRewriteGlobalVariableDescriptor>( + Source, Transform)); return true; } @@ -483,67 +486,80 @@ parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, } if (!Target.empty()) - DL->push_back(new ExplicitRewriteNamedAliasDescriptor(Source, Target, - /*Naked*/false)); + DL->push_back(llvm::make_unique<ExplicitRewriteNamedAliasDescriptor>( + Source, Target, + /*Naked*/ false)); else - DL->push_back(new PatternRewriteNamedAliasDescriptor(Source, Transform)); + DL->push_back(llvm::make_unique<PatternRewriteNamedAliasDescriptor>( + Source, Transform)); return true; } namespace { -class RewriteSymbols : public ModulePass { +class RewriteSymbolsLegacyPass : public ModulePass { public: static char ID; // Pass identification, replacement for typeid - RewriteSymbols(); - RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL); + RewriteSymbolsLegacyPass(); + RewriteSymbolsLegacyPass(SymbolRewriter::RewriteDescriptorList &DL); bool runOnModule(Module &M) override; private: - void loadAndParseMapFiles(); - - SymbolRewriter::RewriteDescriptorList Descriptors; + RewriteSymbolPass Impl; }; -char RewriteSymbols::ID = 0; +char RewriteSymbolsLegacyPass::ID = 0; -RewriteSymbols::RewriteSymbols() : ModulePass(ID) { - initializeRewriteSymbolsPass(*PassRegistry::getPassRegistry()); - loadAndParseMapFiles(); +RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass() : ModulePass(ID), Impl() { + initializeRewriteSymbolsLegacyPassPass(*PassRegistry::getPassRegistry()); } -RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL) - : ModulePass(ID) { - Descriptors.splice(Descriptors.begin(), DL); +RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass( + SymbolRewriter::RewriteDescriptorList &DL) + : ModulePass(ID), Impl(DL) {} + +bool RewriteSymbolsLegacyPass::runOnModule(Module &M) { + return Impl.runImpl(M); +} } -bool RewriteSymbols::runOnModule(Module &M) { +namespace llvm { +PreservedAnalyses RewriteSymbolPass::run(Module &M, ModuleAnalysisManager &AM) { + if (!runImpl(M)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +bool RewriteSymbolPass::runImpl(Module &M) { bool Changed; Changed = false; for (auto &Descriptor : Descriptors) - Changed |= Descriptor.performOnModule(M); + Changed |= Descriptor->performOnModule(M); return Changed; } -void RewriteSymbols::loadAndParseMapFiles() { +void RewriteSymbolPass::loadAndParseMapFiles() { const std::vector<std::string> MapFiles(RewriteMapFiles); - SymbolRewriter::RewriteMapParser parser; + SymbolRewriter::RewriteMapParser Parser; for (const auto &MapFile : MapFiles) - parser.parse(MapFile, &Descriptors); + Parser.parse(MapFile, &Descriptors); } } -INITIALIZE_PASS(RewriteSymbols, "rewrite-symbols", "Rewrite Symbols", false, - false) +INITIALIZE_PASS(RewriteSymbolsLegacyPass, "rewrite-symbols", "Rewrite Symbols", + false, false) -ModulePass *llvm::createRewriteSymbolsPass() { return new RewriteSymbols(); } +ModulePass *llvm::createRewriteSymbolsPass() { + return new RewriteSymbolsLegacyPass(); +} ModulePass * llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) { - return new RewriteSymbols(DL); + return new RewriteSymbolsLegacyPass(DL); } diff --git a/contrib/llvm/lib/Transforms/Utils/Utils.cpp b/contrib/llvm/lib/Transforms/Utils/Utils.cpp index 
8f85f19..7b9de2e 100644 --- a/contrib/llvm/lib/Transforms/Utils/Utils.cpp +++ b/contrib/llvm/lib/Transforms/Utils/Utils.cpp @@ -25,16 +25,19 @@ void llvm::initializeTransformUtils(PassRegistry &Registry) { initializeBreakCriticalEdgesPass(Registry); initializeInstNamerPass(Registry); initializeLCSSAWrapperPassPass(Registry); + initializeLibCallsShrinkWrapLegacyPassPass(Registry); initializeLoopSimplifyPass(Registry); - initializeLowerInvokePass(Registry); + initializeLowerInvokeLegacyPassPass(Registry); initializeLowerSwitchPass(Registry); - initializeNameAnonFunctionPass(Registry); + initializeNameAnonGlobalLegacyPassPass(Registry); initializePromoteLegacyPassPass(Registry); + initializeStripNonLineTableDebugInfoPass(Registry); initializeUnifyFunctionExitNodesPass(Registry); initializeInstSimplifierPass(Registry); initializeMetaRenamerPass(Registry); initializeMemorySSAWrapperPassPass(Registry); initializeMemorySSAPrinterLegacyPassPass(Registry); + initializeStripGCRelocatesPass(Registry); } /// LLVMInitializeTransformUtils - C binding for initializeTransformUtilsPasses. diff --git a/contrib/llvm/lib/Transforms/Utils/ValueMapper.cpp b/contrib/llvm/lib/Transforms/Utils/ValueMapper.cpp index 2eade8c..0e9baaf 100644 --- a/contrib/llvm/lib/Transforms/Utils/ValueMapper.cpp +++ b/contrib/llvm/lib/Transforms/Utils/ValueMapper.cpp @@ -38,15 +38,6 @@ struct DelayedBasicBlock { BasicBlock *OldBB; std::unique_ptr<BasicBlock> TempBB; - // Explicit move for MSVC. - DelayedBasicBlock(DelayedBasicBlock &&X) - : OldBB(std::move(X.OldBB)), TempBB(std::move(X.TempBB)) {} - DelayedBasicBlock &operator=(DelayedBasicBlock &&X) { - OldBB = std::move(X.OldBB); - TempBB = std::move(X.TempBB); - return *this; - } - DelayedBasicBlock(const BlockAddress &Old) : OldBB(Old.getBasicBlock()), TempBB(BasicBlock::Create(Old.getContext())) {} @@ -184,17 +175,6 @@ class MDNodeMapper { bool HasChanged = false; unsigned ID = ~0u; TempMDNode Placeholder; - - Data() {} - Data(Data &&X) - : HasChanged(std::move(X.HasChanged)), ID(std::move(X.ID)), - Placeholder(std::move(X.Placeholder)) {} - Data &operator=(Data &&X) { - HasChanged = std::move(X.HasChanged); - ID = std::move(X.ID); - Placeholder = std::move(X.Placeholder); - return *this; - } }; /// A graph of uniqued nodes. 
@@ -671,7 +651,7 @@ void MDNodeMapper::UniquedGraph::propagateChanges() { if (D.HasChanged) continue; - if (!llvm::any_of(N->operands(), [&](const Metadata *Op) { + if (none_of(N->operands(), [&](const Metadata *Op) { auto Where = Info.find(Op); return Where != Info.end() && Where->second.HasChanged; })) diff --git a/contrib/llvm/lib/Transforms/Vectorize/BBVectorize.cpp b/contrib/llvm/lib/Transforms/Vectorize/BBVectorize.cpp index af594cb..c01740b 100644 --- a/contrib/llvm/lib/Transforms/Vectorize/BBVectorize.cpp +++ b/contrib/llvm/lib/Transforms/Vectorize/BBVectorize.cpp @@ -3148,7 +3148,7 @@ namespace { LLVMContext::MD_noalias, LLVMContext::MD_fpmath, LLVMContext::MD_invariant_group}; combineMetadata(K, H, KnownIDs); - K->intersectOptionalDataWith(H); + K->andIRFlags(H); for (unsigned o = 0; o < NumOperands; ++o) K->setOperand(o, ReplacedOperands[o]); diff --git a/contrib/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/contrib/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index c8906bd..c44a393 100644 --- a/contrib/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/contrib/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/OrderedBasicBlock.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -30,6 +31,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Vectorize.h" using namespace llvm; @@ -40,13 +42,12 @@ STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized"); namespace { -// TODO: Remove this -static const unsigned TargetBaseAlign = 4; +// FIXME: Assuming stack alignment of 4 is always good enough +static const unsigned StackAdjustedAlignment = 4; +typedef SmallVector<Instruction *, 8> InstrList; +typedef MapVector<Value *, InstrList> InstrListMap; class Vectorizer { - typedef SmallVector<Value *, 8> ValueList; - typedef MapVector<Value *, ValueList> ValueListMap; - Function &F; AliasAnalysis &AA; DominatorTree &DT; @@ -54,8 +55,6 @@ class Vectorizer { TargetTransformInfo &TTI; const DataLayout &DL; IRBuilder<> Builder; - ValueListMap StoreRefs; - ValueListMap LoadRefs; public: Vectorizer(Function &F, AliasAnalysis &AA, DominatorTree &DT, @@ -94,45 +93,47 @@ private: /// Returns the first and the last instructions in Chain. std::pair<BasicBlock::iterator, BasicBlock::iterator> - getBoundaryInstrs(ArrayRef<Value *> Chain); + getBoundaryInstrs(ArrayRef<Instruction *> Chain); /// Erases the original instructions after vectorizing. - void eraseInstructions(ArrayRef<Value *> Chain); + void eraseInstructions(ArrayRef<Instruction *> Chain); /// "Legalize" the vector type that would be produced by combining \p /// ElementSizeBits elements in \p Chain. Break into two pieces such that the /// total size of each piece is 1, 2 or a multiple of 4 bytes. \p Chain is /// expected to have more than 4 elements. - std::pair<ArrayRef<Value *>, ArrayRef<Value *>> - splitOddVectorElts(ArrayRef<Value *> Chain, unsigned ElementSizeBits); + std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>> + splitOddVectorElts(ArrayRef<Instruction *> Chain, unsigned ElementSizeBits); - /// Checks for instructions which may affect the memory accessed - /// in the chain between \p From and \p To. 
Returns Index, where - /// \p Chain[0, Index) is the largest vectorizable chain prefix. - /// The elements of \p Chain should be all loads or all stores. - unsigned getVectorizablePrefixEndIdx(ArrayRef<Value *> Chain, - BasicBlock::iterator From, - BasicBlock::iterator To); + /// Finds the largest prefix of Chain that's vectorizable, checking for + /// intervening instructions which may affect the memory accessed by the + /// instructions within Chain. + /// + /// The elements of \p Chain must be all loads or all stores and must be in + /// address order. + ArrayRef<Instruction *> getVectorizablePrefix(ArrayRef<Instruction *> Chain); /// Collects load and store instructions to vectorize. - void collectInstructions(BasicBlock *BB); + std::pair<InstrListMap, InstrListMap> collectInstructions(BasicBlock *BB); - /// Processes the collected instructions, the \p Map. The elements of \p Map + /// Processes the collected instructions, the \p Map. The values of \p Map /// should be all loads or all stores. - bool vectorizeChains(ValueListMap &Map); + bool vectorizeChains(InstrListMap &Map); /// Finds the load/stores to consecutive memory addresses and vectorizes them. - bool vectorizeInstructions(ArrayRef<Value *> Instrs); + bool vectorizeInstructions(ArrayRef<Instruction *> Instrs); /// Vectorizes the load instructions in Chain. - bool vectorizeLoadChain(ArrayRef<Value *> Chain, - SmallPtrSet<Value *, 16> *InstructionsProcessed); + bool + vectorizeLoadChain(ArrayRef<Instruction *> Chain, + SmallPtrSet<Instruction *, 16> *InstructionsProcessed); /// Vectorizes the store instructions in Chain. - bool vectorizeStoreChain(ArrayRef<Value *> Chain, - SmallPtrSet<Value *, 16> *InstructionsProcessed); + bool + vectorizeStoreChain(ArrayRef<Instruction *> Chain, + SmallPtrSet<Instruction *, 16> *InstructionsProcessed); - /// Check if this load/store access is misaligned accesses + /// Check if this load/store access is misaligned accesses. bool accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, unsigned Alignment); }; @@ -147,7 +148,7 @@ public: bool runOnFunction(Function &F) override; - const char *getPassName() const override { + StringRef getPassName() const override { return "GPU Load and Store Vectorizer"; } @@ -177,6 +178,13 @@ Pass *llvm::createLoadStoreVectorizerPass() { return new LoadStoreVectorizer(); } +// The real propagateMetadata expects a SmallVector<Value*>, but we deal in +// vectors of Instructions. +static void propagateMetadata(Instruction *I, ArrayRef<Instruction *> IL) { + SmallVector<Value *, 8> VL(IL.begin(), IL.end()); + propagateMetadata(I, VL); +} + bool LoadStoreVectorizer::runOnFunction(Function &F) { // Don't vectorize when the attribute NoImplicitFloat is used. if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat)) @@ -198,7 +206,8 @@ bool Vectorizer::run() { // Scan the blocks in the function in post order. 
for (BasicBlock *BB : post_order(&F)) { - collectInstructions(BB); + InstrListMap LoadRefs, StoreRefs; + std::tie(LoadRefs, StoreRefs) = collectInstructions(BB); Changed |= vectorizeChains(LoadRefs); Changed |= vectorizeChains(StoreRefs); } @@ -338,6 +347,7 @@ bool Vectorizer::isConsecutiveAccess(Value *A, Value *B) { } void Vectorizer::reorder(Instruction *I) { + OrderedBasicBlock OBB(I->getParent()); SmallPtrSet<Instruction *, 16> InstructionsToMove; SmallVector<Instruction *, 16> Worklist; @@ -350,11 +360,14 @@ void Vectorizer::reorder(Instruction *I) { if (!IM || IM->getOpcode() == Instruction::PHI) continue; - if (!DT.dominates(IM, I)) { + // If IM is in another BB, no need to move it, because this pass only + // vectorizes instructions within one BB. + if (IM->getParent() != I->getParent()) + continue; + + if (!OBB.dominates(IM, I)) { InstructionsToMove.insert(IM); Worklist.push_back(IM); - assert(IM->getParent() == IW->getParent() && - "Instructions to move should be in the same basic block"); } } } @@ -362,7 +375,7 @@ void Vectorizer::reorder(Instruction *I) { // All instructions to move should follow I. Start from I, not from begin(). for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E; ++BBI) { - if (!is_contained(InstructionsToMove, &*BBI)) + if (!InstructionsToMove.count(&*BBI)) continue; Instruction *IM = &*BBI; --BBI; @@ -372,8 +385,8 @@ void Vectorizer::reorder(Instruction *I) { } std::pair<BasicBlock::iterator, BasicBlock::iterator> -Vectorizer::getBoundaryInstrs(ArrayRef<Value *> Chain) { - Instruction *C0 = cast<Instruction>(Chain[0]); +Vectorizer::getBoundaryInstrs(ArrayRef<Instruction *> Chain) { + Instruction *C0 = Chain[0]; BasicBlock::iterator FirstInstr = C0->getIterator(); BasicBlock::iterator LastInstr = C0->getIterator(); @@ -397,105 +410,152 @@ Vectorizer::getBoundaryInstrs(ArrayRef<Value *> Chain) { return std::make_pair(FirstInstr, ++LastInstr); } -void Vectorizer::eraseInstructions(ArrayRef<Value *> Chain) { +void Vectorizer::eraseInstructions(ArrayRef<Instruction *> Chain) { SmallVector<Instruction *, 16> Instrs; - for (Value *V : Chain) { - Value *PtrOperand = getPointerOperand(V); + for (Instruction *I : Chain) { + Value *PtrOperand = getPointerOperand(I); assert(PtrOperand && "Instruction must have a pointer operand."); - Instrs.push_back(cast<Instruction>(V)); + Instrs.push_back(I); if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(PtrOperand)) Instrs.push_back(GEP); } // Erase instructions. 
- for (Value *V : Instrs) { - Instruction *Instr = cast<Instruction>(V); - if (Instr->use_empty()) - Instr->eraseFromParent(); - } + for (Instruction *I : Instrs) + if (I->use_empty()) + I->eraseFromParent(); } -std::pair<ArrayRef<Value *>, ArrayRef<Value *>> -Vectorizer::splitOddVectorElts(ArrayRef<Value *> Chain, +std::pair<ArrayRef<Instruction *>, ArrayRef<Instruction *>> +Vectorizer::splitOddVectorElts(ArrayRef<Instruction *> Chain, unsigned ElementSizeBits) { - unsigned ElemSizeInBytes = ElementSizeBits / 8; - unsigned SizeInBytes = ElemSizeInBytes * Chain.size(); - unsigned NumRight = (SizeInBytes % 4) / ElemSizeInBytes; - unsigned NumLeft = Chain.size() - NumRight; + unsigned ElementSizeBytes = ElementSizeBits / 8; + unsigned SizeBytes = ElementSizeBytes * Chain.size(); + unsigned NumLeft = (SizeBytes - (SizeBytes % 4)) / ElementSizeBytes; + if (NumLeft == Chain.size()) + --NumLeft; + else if (NumLeft == 0) + NumLeft = 1; return std::make_pair(Chain.slice(0, NumLeft), Chain.slice(NumLeft)); } -unsigned Vectorizer::getVectorizablePrefixEndIdx(ArrayRef<Value *> Chain, - BasicBlock::iterator From, - BasicBlock::iterator To) { - SmallVector<std::pair<Value *, unsigned>, 16> MemoryInstrs; - SmallVector<std::pair<Value *, unsigned>, 16> ChainInstrs; +ArrayRef<Instruction *> +Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) { + // These are in BB order, unlike Chain, which is in address order. + SmallVector<Instruction *, 16> MemoryInstrs; + SmallVector<Instruction *, 16> ChainInstrs; + + bool IsLoadChain = isa<LoadInst>(Chain[0]); + DEBUG({ + for (Instruction *I : Chain) { + if (IsLoadChain) + assert(isa<LoadInst>(I) && + "All elements of Chain must be loads, or all must be stores."); + else + assert(isa<StoreInst>(I) && + "All elements of Chain must be loads, or all must be stores."); + } + }); - unsigned InstrIdx = 0; - for (auto I = From; I != To; ++I, ++InstrIdx) { + for (Instruction &I : make_range(getBoundaryInstrs(Chain))) { if (isa<LoadInst>(I) || isa<StoreInst>(I)) { - if (!is_contained(Chain, &*I)) - MemoryInstrs.push_back({&*I, InstrIdx}); + if (!is_contained(Chain, &I)) + MemoryInstrs.push_back(&I); else - ChainInstrs.push_back({&*I, InstrIdx}); - } else if (I->mayHaveSideEffects()) { - DEBUG(dbgs() << "LSV: Found side-effecting operation: " << *I << '\n'); - return 0; + ChainInstrs.push_back(&I); + } else if (IsLoadChain && (I.mayWriteToMemory() || I.mayThrow())) { + DEBUG(dbgs() << "LSV: Found may-write/throw operation: " << I << '\n'); + break; + } else if (!IsLoadChain && (I.mayReadOrWriteMemory() || I.mayThrow())) { + DEBUG(dbgs() << "LSV: Found may-read/write/throw operation: " << I + << '\n'); + break; } } - assert(Chain.size() == ChainInstrs.size() && - "All instructions in the Chain must exist in [From, To)."); + OrderedBasicBlock OBB(Chain[0]->getParent()); - unsigned ChainIdx = 0; - for (auto EntryChain : ChainInstrs) { - Value *ChainInstrValue = EntryChain.first; - unsigned ChainInstrIdx = EntryChain.second; - for (auto EntryMem : MemoryInstrs) { - Value *MemInstrValue = EntryMem.first; - unsigned MemInstrIdx = EntryMem.second; - if (isa<LoadInst>(MemInstrValue) && isa<LoadInst>(ChainInstrValue)) + // Loop until we find an instruction in ChainInstrs that we can't vectorize. 
+ unsigned ChainInstrIdx = 0; + Instruction *BarrierMemoryInstr = nullptr; + + for (unsigned E = ChainInstrs.size(); ChainInstrIdx < E; ++ChainInstrIdx) { + Instruction *ChainInstr = ChainInstrs[ChainInstrIdx]; + + // If a barrier memory instruction was found, chain instructions that follow + // will not be added to the valid prefix. + if (BarrierMemoryInstr && OBB.dominates(BarrierMemoryInstr, ChainInstr)) + break; + + // Check (in BB order) if any instruction prevents ChainInstr from being + // vectorized. Find and store the first such "conflicting" instruction. + for (Instruction *MemInstr : MemoryInstrs) { + // If a barrier memory instruction was found, do not check past it. + if (BarrierMemoryInstr && OBB.dominates(BarrierMemoryInstr, MemInstr)) + break; + + if (isa<LoadInst>(MemInstr) && isa<LoadInst>(ChainInstr)) continue; // We can ignore the alias as long as the load comes before the store, // because that means we won't be moving the load past the store to // vectorize it (the vectorized load is inserted at the location of the // first load in the chain). - if (isa<StoreInst>(MemInstrValue) && isa<LoadInst>(ChainInstrValue) && - ChainInstrIdx < MemInstrIdx) + if (isa<StoreInst>(MemInstr) && isa<LoadInst>(ChainInstr) && + OBB.dominates(ChainInstr, MemInstr)) continue; // Same case, but in reverse. - if (isa<LoadInst>(MemInstrValue) && isa<StoreInst>(ChainInstrValue) && - ChainInstrIdx > MemInstrIdx) + if (isa<LoadInst>(MemInstr) && isa<StoreInst>(ChainInstr) && + OBB.dominates(MemInstr, ChainInstr)) continue; - Instruction *M0 = cast<Instruction>(MemInstrValue); - Instruction *M1 = cast<Instruction>(ChainInstrValue); - - if (!AA.isNoAlias(MemoryLocation::get(M0), MemoryLocation::get(M1))) { + if (!AA.isNoAlias(MemoryLocation::get(MemInstr), + MemoryLocation::get(ChainInstr))) { DEBUG({ - Value *Ptr0 = getPointerOperand(M0); - Value *Ptr1 = getPointerOperand(M1); - - dbgs() << "LSV: Found alias.\n" - " Aliasing instruction and pointer:\n" - << *MemInstrValue << " aliases " << *Ptr0 << '\n' - << " Aliased instruction and pointer:\n" - << *ChainInstrValue << " aliases " << *Ptr1 << '\n'; + dbgs() << "LSV: Found alias:\n" + " Aliasing instruction and pointer:\n" + << " " << *MemInstr << '\n' + << " " << *getPointerOperand(MemInstr) << '\n' + << " Aliased instruction and pointer:\n" + << " " << *ChainInstr << '\n' + << " " << *getPointerOperand(ChainInstr) << '\n'; }); - - return ChainIdx; + // Save this aliasing memory instruction as a barrier, but allow other + // instructions that precede the barrier to be vectorized with this one. + BarrierMemoryInstr = MemInstr; + break; } } - ChainIdx++; + // Continue the search only for store chains, since vectorizing stores that + // precede an aliasing load is valid. Conversely, vectorizing loads is valid + // up to an aliasing store, but should not pull loads from further down in + // the basic block. + if (IsLoadChain && BarrierMemoryInstr) { + // The BarrierMemoryInstr is a store that precedes ChainInstr. + assert(OBB.dominates(BarrierMemoryInstr, ChainInstr)); + break; + } } - return Chain.size(); + + // Find the largest prefix of Chain whose elements are all in + // ChainInstrs[0, ChainInstrIdx). This is the largest vectorizable prefix of + // Chain. (Recall that Chain is in address order, but ChainInstrs is in BB + // order.) 
+ SmallPtrSet<Instruction *, 8> VectorizableChainInstrs( + ChainInstrs.begin(), ChainInstrs.begin() + ChainInstrIdx); + unsigned ChainIdx = 0; + for (unsigned ChainLen = Chain.size(); ChainIdx < ChainLen; ++ChainIdx) { + if (!VectorizableChainInstrs.count(Chain[ChainIdx])) + break; + } + return Chain.slice(0, ChainIdx); } -void Vectorizer::collectInstructions(BasicBlock *BB) { - LoadRefs.clear(); - StoreRefs.clear(); +std::pair<InstrListMap, InstrListMap> +Vectorizer::collectInstructions(BasicBlock *BB) { + InstrListMap LoadRefs; + InstrListMap StoreRefs; for (Instruction &I : *BB) { if (!I.mayReadOrWriteMemory()) @@ -505,6 +565,10 @@ void Vectorizer::collectInstructions(BasicBlock *BB) { if (!LI->isSimple()) continue; + // Skip if it's not legal. + if (!TTI.isLegalToVectorizeLoad(LI)) + continue; + Type *Ty = LI->getType(); if (!VectorType::isValidElementType(Ty->getScalarType())) continue; @@ -525,14 +589,11 @@ void Vectorizer::collectInstructions(BasicBlock *BB) { // Make sure all the users of a vector are constant-index extracts. if (isa<VectorType>(Ty) && !all_of(LI->users(), [LI](const User *U) { - const Instruction *UI = cast<Instruction>(U); - return isa<ExtractElementInst>(UI) && - isa<ConstantInt>(UI->getOperand(1)); + const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U); + return EEI && isa<ConstantInt>(EEI->getOperand(1)); })) continue; - // TODO: Target hook to filter types. - // Save the load locations. Value *ObjPtr = GetUnderlyingObject(Ptr, DL); LoadRefs[ObjPtr].push_back(LI); @@ -541,6 +602,10 @@ void Vectorizer::collectInstructions(BasicBlock *BB) { if (!SI->isSimple()) continue; + // Skip if it's not legal. + if (!TTI.isLegalToVectorizeStore(SI)) + continue; + Type *Ty = SI->getValueOperand()->getType(); if (!VectorType::isValidElementType(Ty->getScalarType())) continue; @@ -558,9 +623,8 @@ void Vectorizer::collectInstructions(BasicBlock *BB) { continue; if (isa<VectorType>(Ty) && !all_of(SI->users(), [SI](const User *U) { - const Instruction *UI = cast<Instruction>(U); - return isa<ExtractElementInst>(UI) && - isa<ConstantInt>(UI->getOperand(1)); + const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U); + return EEI && isa<ConstantInt>(EEI->getOperand(1)); })) continue; @@ -569,12 +633,14 @@ void Vectorizer::collectInstructions(BasicBlock *BB) { StoreRefs[ObjPtr].push_back(SI); } } + + return {LoadRefs, StoreRefs}; } -bool Vectorizer::vectorizeChains(ValueListMap &Map) { +bool Vectorizer::vectorizeChains(InstrListMap &Map) { bool Changed = false; - for (const std::pair<Value *, ValueList> &Chain : Map) { + for (const std::pair<Value *, InstrList> &Chain : Map) { unsigned Size = Chain.second.size(); if (Size < 2) continue; @@ -584,7 +650,7 @@ bool Vectorizer::vectorizeChains(ValueListMap &Map) { // Process the stores in chunks of 64. 
for (unsigned CI = 0, CE = Size; CI < CE; CI += 64) { unsigned Len = std::min<unsigned>(CE - CI, 64); - ArrayRef<Value *> Chunk(&Chain.second[CI], Len); + ArrayRef<Instruction *> Chunk(&Chain.second[CI], Len); Changed |= vectorizeInstructions(Chunk); } } @@ -592,9 +658,9 @@ bool Vectorizer::vectorizeChains(ValueListMap &Map) { return Changed; } -bool Vectorizer::vectorizeInstructions(ArrayRef<Value *> Instrs) { +bool Vectorizer::vectorizeInstructions(ArrayRef<Instruction *> Instrs) { DEBUG(dbgs() << "LSV: Vectorizing " << Instrs.size() << " instructions.\n"); - SmallSetVector<int, 16> Heads, Tails; + SmallVector<int, 16> Heads, Tails; int ConsecutiveChain[64]; // Do a quadratic search on all of the given stores and find all of the pairs @@ -613,34 +679,34 @@ bool Vectorizer::vectorizeInstructions(ArrayRef<Value *> Instrs) { continue; // Should not insert. } - Tails.insert(j); - Heads.insert(i); + Tails.push_back(j); + Heads.push_back(i); ConsecutiveChain[i] = j; } } } bool Changed = false; - SmallPtrSet<Value *, 16> InstructionsProcessed; + SmallPtrSet<Instruction *, 16> InstructionsProcessed; for (int Head : Heads) { if (InstructionsProcessed.count(Instrs[Head])) continue; - bool longerChainExists = false; + bool LongerChainExists = false; for (unsigned TIt = 0; TIt < Tails.size(); TIt++) if (Head == Tails[TIt] && !InstructionsProcessed.count(Instrs[Heads[TIt]])) { - longerChainExists = true; + LongerChainExists = true; break; } - if (longerChainExists) + if (LongerChainExists) continue; // We found an instr that starts a chain. Now follow the chain and try to // vectorize it. - SmallVector<Value *, 16> Operands; + SmallVector<Instruction *, 16> Operands; int I = Head; - while (I != -1 && (Tails.count(I) || Heads.count(I))) { + while (I != -1 && (is_contained(Tails, I) || is_contained(Heads, I))) { if (InstructionsProcessed.count(Instrs[I])) break; @@ -661,13 +727,14 @@ bool Vectorizer::vectorizeInstructions(ArrayRef<Value *> Instrs) { } bool Vectorizer::vectorizeStoreChain( - ArrayRef<Value *> Chain, SmallPtrSet<Value *, 16> *InstructionsProcessed) { + ArrayRef<Instruction *> Chain, + SmallPtrSet<Instruction *, 16> *InstructionsProcessed) { StoreInst *S0 = cast<StoreInst>(Chain[0]); // If the vector has an int element, default to int for the whole load. Type *StoreTy; - for (const auto &V : Chain) { - StoreTy = cast<StoreInst>(V)->getValueOperand()->getType(); + for (Instruction *I : Chain) { + StoreTy = cast<StoreInst>(I)->getValueOperand()->getType(); if (StoreTy->isIntOrIntVectorTy()) break; @@ -683,40 +750,34 @@ bool Vectorizer::vectorizeStoreChain( unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); unsigned VF = VecRegSize / Sz; unsigned ChainSize = Chain.size(); + unsigned Alignment = getAlignment(S0); if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) { InstructionsProcessed->insert(Chain.begin(), Chain.end()); return false; } - BasicBlock::iterator First, Last; - std::tie(First, Last) = getBoundaryInstrs(Chain); - unsigned StopChain = getVectorizablePrefixEndIdx(Chain, First, Last); - if (StopChain == 0) { - // There exists a side effect instruction, no vectorization possible. + ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain); + if (NewChain.empty()) { + // No vectorization possible. InstructionsProcessed->insert(Chain.begin(), Chain.end()); return false; } - if (StopChain == 1) { + if (NewChain.size() == 1) { // Failed after the first instruction. Discard it and try the smaller chain. 
- InstructionsProcessed->insert(Chain.front()); + InstructionsProcessed->insert(NewChain.front()); return false; } // Update Chain to the valid vectorizable subchain. - Chain = Chain.slice(0, StopChain); + Chain = NewChain; ChainSize = Chain.size(); - // Store size should be 1B, 2B or multiple of 4B. - // TODO: Target hook for size constraint? - unsigned SzInBytes = (Sz / 8) * ChainSize; - if (SzInBytes > 2 && SzInBytes % 4 != 0) { - DEBUG(dbgs() << "LSV: Size should be 1B, 2B " - "or multiple of 4B. Splitting.\n"); - if (SzInBytes == 3) - return vectorizeStoreChain(Chain.slice(0, ChainSize - 1), - InstructionsProcessed); - + // Check if it's legal to vectorize this chain. If not, split the chain and + // try again. + unsigned EltSzInBytes = Sz / 8; + unsigned SzInBytes = EltSzInBytes * ChainSize; + if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) { auto Chains = splitOddVectorElts(Chain, Sz); return vectorizeStoreChain(Chains.first, InstructionsProcessed) | vectorizeStoreChain(Chains.second, InstructionsProcessed); @@ -730,45 +791,41 @@ bool Vectorizer::vectorizeStoreChain( else VecTy = VectorType::get(StoreTy, Chain.size()); - // If it's more than the max vector size, break it into two pieces. - // TODO: Target hook to control types to split to. - if (ChainSize > VF) { - DEBUG(dbgs() << "LSV: Vector factor is too big." + // If it's more than the max vector size or the target has a better + // vector factor, break it into two pieces. + unsigned TargetVF = TTI.getStoreVectorFactor(VF, Sz, SzInBytes, VecTy); + if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { + DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." " Creating two separate arrays.\n"); - return vectorizeStoreChain(Chain.slice(0, VF), InstructionsProcessed) | - vectorizeStoreChain(Chain.slice(VF), InstructionsProcessed); + return vectorizeStoreChain(Chain.slice(0, TargetVF), + InstructionsProcessed) | + vectorizeStoreChain(Chain.slice(TargetVF), InstructionsProcessed); } DEBUG({ dbgs() << "LSV: Stores to vectorize:\n"; - for (Value *V : Chain) - V->dump(); + for (Instruction *I : Chain) + dbgs() << " " << *I << "\n"; }); // We won't try again to vectorize the elements of the chain, regardless of // whether we succeed below. InstructionsProcessed->insert(Chain.begin(), Chain.end()); - // Check alignment restrictions. - unsigned Alignment = getAlignment(S0); - // If the store is going to be misaligned, don't vectorize it. if (accessIsMisaligned(SzInBytes, AS, Alignment)) { if (S0->getPointerAddressSpace() != 0) return false; - // If we're storing to an object on the stack, we control its alignment, - // so we can cheat and change it! - Value *V = GetUnderlyingObject(S0->getPointerOperand(), DL); - if (AllocaInst *AI = dyn_cast_or_null<AllocaInst>(V)) { - AI->setAlignment(TargetBaseAlign); - Alignment = TargetBaseAlign; - } else { + unsigned NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(), + StackAdjustedAlignment, + DL, S0, nullptr, &DT); + if (NewAlign < StackAdjustedAlignment) return false; - } } - // Set insert point. 
+ BasicBlock::iterator First, Last; + std::tie(First, Last) = getBoundaryInstrs(Chain); Builder.SetInsertPoint(&*Last); Value *Vec = UndefValue::get(VecTy); @@ -803,9 +860,11 @@ bool Vectorizer::vectorizeStoreChain( } } - Value *Bitcast = - Builder.CreateBitCast(S0->getPointerOperand(), VecTy->getPointerTo(AS)); - StoreInst *SI = cast<StoreInst>(Builder.CreateStore(Vec, Bitcast)); + // This cast is safe because Builder.CreateStore() always creates a bona fide + // StoreInst. + StoreInst *SI = cast<StoreInst>( + Builder.CreateStore(Vec, Builder.CreateBitCast(S0->getPointerOperand(), + VecTy->getPointerTo(AS)))); propagateMetadata(SI, Chain); SI->setAlignment(Alignment); @@ -816,7 +875,8 @@ bool Vectorizer::vectorizeStoreChain( } bool Vectorizer::vectorizeLoadChain( - ArrayRef<Value *> Chain, SmallPtrSet<Value *, 16> *InstructionsProcessed) { + ArrayRef<Instruction *> Chain, + SmallPtrSet<Instruction *, 16> *InstructionsProcessed) { LoadInst *L0 = cast<LoadInst>(Chain[0]); // If the vector has an int element, default to int for the whole load. @@ -838,39 +898,34 @@ bool Vectorizer::vectorizeLoadChain( unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); unsigned VF = VecRegSize / Sz; unsigned ChainSize = Chain.size(); + unsigned Alignment = getAlignment(L0); if (!isPowerOf2_32(Sz) || VF < 2 || ChainSize < 2) { InstructionsProcessed->insert(Chain.begin(), Chain.end()); return false; } - BasicBlock::iterator First, Last; - std::tie(First, Last) = getBoundaryInstrs(Chain); - unsigned StopChain = getVectorizablePrefixEndIdx(Chain, First, Last); - if (StopChain == 0) { - // There exists a side effect instruction, no vectorization possible. + ArrayRef<Instruction *> NewChain = getVectorizablePrefix(Chain); + if (NewChain.empty()) { + // No vectorization possible. InstructionsProcessed->insert(Chain.begin(), Chain.end()); return false; } - if (StopChain == 1) { + if (NewChain.size() == 1) { // Failed after the first instruction. Discard it and try the smaller chain. - InstructionsProcessed->insert(Chain.front()); + InstructionsProcessed->insert(NewChain.front()); return false; } // Update Chain to the valid vectorizable subchain. - Chain = Chain.slice(0, StopChain); + Chain = NewChain; ChainSize = Chain.size(); - // Load size should be 1B, 2B or multiple of 4B. - // TODO: Should size constraint be a target hook? - unsigned SzInBytes = (Sz / 8) * ChainSize; - if (SzInBytes > 2 && SzInBytes % 4 != 0) { - DEBUG(dbgs() << "LSV: Size should be 1B, 2B " - "or multiple of 4B. Splitting.\n"); - if (SzInBytes == 3) - return vectorizeLoadChain(Chain.slice(0, ChainSize - 1), - InstructionsProcessed); + // Check if it's legal to vectorize this chain. If not, split the chain and + // try again. + unsigned EltSzInBytes = Sz / 8; + unsigned SzInBytes = EltSzInBytes * ChainSize; + if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) { auto Chains = splitOddVectorElts(Chain, Sz); return vectorizeLoadChain(Chains.first, InstructionsProcessed) | vectorizeLoadChain(Chains.second, InstructionsProcessed); @@ -884,101 +939,99 @@ bool Vectorizer::vectorizeLoadChain( else VecTy = VectorType::get(LoadTy, Chain.size()); - // If it's more than the max vector size, break it into two pieces. - // TODO: Target hook to control types to split to. - if (ChainSize > VF) { - DEBUG(dbgs() << "LSV: Vector factor is too big. 
" - "Creating two separate arrays.\n"); - return vectorizeLoadChain(Chain.slice(0, VF), InstructionsProcessed) | - vectorizeLoadChain(Chain.slice(VF), InstructionsProcessed); + // If it's more than the max vector size or the target has a better + // vector factor, break it into two pieces. + unsigned TargetVF = TTI.getLoadVectorFactor(VF, Sz, SzInBytes, VecTy); + if (ChainSize > VF || (VF != TargetVF && TargetVF < ChainSize)) { + DEBUG(dbgs() << "LSV: Chain doesn't match with the vector factor." + " Creating two separate arrays.\n"); + return vectorizeLoadChain(Chain.slice(0, TargetVF), InstructionsProcessed) | + vectorizeLoadChain(Chain.slice(TargetVF), InstructionsProcessed); } // We won't try again to vectorize the elements of the chain, regardless of // whether we succeed below. InstructionsProcessed->insert(Chain.begin(), Chain.end()); - // Check alignment restrictions. - unsigned Alignment = getAlignment(L0); - // If the load is going to be misaligned, don't vectorize it. if (accessIsMisaligned(SzInBytes, AS, Alignment)) { if (L0->getPointerAddressSpace() != 0) return false; - // If we're loading from an object on the stack, we control its alignment, - // so we can cheat and change it! - Value *V = GetUnderlyingObject(L0->getPointerOperand(), DL); - if (AllocaInst *AI = dyn_cast_or_null<AllocaInst>(V)) { - AI->setAlignment(TargetBaseAlign); - Alignment = TargetBaseAlign; - } else { + unsigned NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(), + StackAdjustedAlignment, + DL, L0, nullptr, &DT); + if (NewAlign < StackAdjustedAlignment) return false; - } + + Alignment = NewAlign; } DEBUG({ dbgs() << "LSV: Loads to vectorize:\n"; - for (Value *V : Chain) - V->dump(); + for (Instruction *I : Chain) + I->dump(); }); - // Set insert point. + // getVectorizablePrefix already computed getBoundaryInstrs. The value of + // Last may have changed since then, but the value of First won't have. If it + // matters, we could compute getBoundaryInstrs only once and reuse it here. + BasicBlock::iterator First, Last; + std::tie(First, Last) = getBoundaryInstrs(Chain); Builder.SetInsertPoint(&*First); Value *Bitcast = Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS)); - + // This cast is safe because Builder.CreateLoad always creates a bona fide + // LoadInst. LoadInst *LI = cast<LoadInst>(Builder.CreateLoad(Bitcast)); propagateMetadata(LI, Chain); LI->setAlignment(Alignment); if (VecLoadTy) { SmallVector<Instruction *, 16> InstrsToErase; - SmallVector<Instruction *, 16> InstrsToReorder; - InstrsToReorder.push_back(cast<Instruction>(Bitcast)); unsigned VecWidth = VecLoadTy->getNumElements(); for (unsigned I = 0, E = Chain.size(); I != E; ++I) { for (auto Use : Chain[I]->users()) { + // All users of vector loads are ExtractElement instructions with + // constant indices, otherwise we would have bailed before now. Instruction *UI = cast<Instruction>(Use); unsigned Idx = cast<ConstantInt>(UI->getOperand(1))->getZExtValue(); unsigned NewIdx = Idx + I * VecWidth; - Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(NewIdx)); - Instruction *Extracted = cast<Instruction>(V); - if (Extracted->getType() != UI->getType()) - Extracted = cast<Instruction>( - Builder.CreateBitCast(Extracted, UI->getType())); + Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(NewIdx), + UI->getName()); + if (V->getType() != UI->getType()) + V = Builder.CreateBitCast(V, UI->getType()); // Replace the old instruction. 
- UI->replaceAllUsesWith(Extracted); + UI->replaceAllUsesWith(V); InstrsToErase.push_back(UI); } } - for (Instruction *ModUser : InstrsToReorder) - reorder(ModUser); + // Bitcast might not be an Instruction, if the value being loaded is a + // constant. In that case, no need to reorder anything. + if (Instruction *BitcastInst = dyn_cast<Instruction>(Bitcast)) + reorder(BitcastInst); for (auto I : InstrsToErase) I->eraseFromParent(); } else { - SmallVector<Instruction *, 16> InstrsToReorder; - InstrsToReorder.push_back(cast<Instruction>(Bitcast)); - for (unsigned I = 0, E = Chain.size(); I != E; ++I) { - Value *V = Builder.CreateExtractElement(LI, Builder.getInt32(I)); - Instruction *Extracted = cast<Instruction>(V); - Instruction *UI = cast<Instruction>(Chain[I]); - if (Extracted->getType() != UI->getType()) { - Extracted = cast<Instruction>( - Builder.CreateBitOrPointerCast(Extracted, UI->getType())); + Value *CV = Chain[I]; + Value *V = + Builder.CreateExtractElement(LI, Builder.getInt32(I), CV->getName()); + if (V->getType() != CV->getType()) { + V = Builder.CreateBitOrPointerCast(V, CV->getType()); } // Replace the old instruction. - UI->replaceAllUsesWith(Extracted); + CV->replaceAllUsesWith(V); } - for (Instruction *ModUser : InstrsToReorder) - reorder(ModUser); + if (Instruction *BitcastInst = dyn_cast<Instruction>(Bitcast)) + reorder(BitcastInst); } eraseInstructions(Chain); @@ -990,10 +1043,14 @@ bool Vectorizer::vectorizeLoadChain( bool Vectorizer::accessIsMisaligned(unsigned SzInBytes, unsigned AddressSpace, unsigned Alignment) { + if (Alignment % SzInBytes == 0) + return false; + bool Fast = false; - bool Allows = TTI.allowsMisalignedMemoryAccesses(SzInBytes * 8, AddressSpace, + bool Allows = TTI.allowsMisalignedMemoryAccesses(F.getParent()->getContext(), + SzInBytes * 8, AddressSpace, Alignment, &Fast); - // TODO: Remove TargetBaseAlign - return !(Allows && Fast) && (Alignment % SzInBytes) != 0 && - (Alignment % TargetBaseAlign) != 0; + DEBUG(dbgs() << "LSV: Target said misaligned is allowed? " << Allows + << " and fast? " << Fast << "\n";); + return !Allows || !Fast; } diff --git a/contrib/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/contrib/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index ee5733d..dac7032 100644 --- a/contrib/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/contrib/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -80,6 +80,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" +#include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/IR/Verifier.h" @@ -191,7 +192,7 @@ static cl::opt<bool> EnableIndVarRegisterHeur( cl::desc("Count the induction variable only once when interleaving")); static cl::opt<bool> EnableCondStoresVectorization( - "enable-cond-stores-vec", cl::init(false), cl::Hidden, + "enable-cond-stores-vec", cl::init(true), cl::Hidden, cl::desc("Enable if predication of stores during vectorization.")); static cl::opt<unsigned> MaxNestedScalarReductionIC( @@ -213,6 +214,32 @@ static cl::opt<unsigned> PragmaVectorizeSCEVCheckThreshold( cl::desc("The maximum number of SCEV checks allowed with a " "vectorize(enable) pragma")); +/// Create an analysis remark that explains why vectorization failed +/// +/// \p PassName is the name of the pass (e.g. can be AlwaysPrint). \p +/// RemarkName is the identifier for the remark. If \p I is passed it is an +/// instruction that prevents vectorization. Otherwise \p TheLoop is used for +/// the location of the remark. 
\return the remark object that can be +/// streamed to. +static OptimizationRemarkAnalysis +createMissedAnalysis(const char *PassName, StringRef RemarkName, Loop *TheLoop, + Instruction *I = nullptr) { + Value *CodeRegion = TheLoop->getHeader(); + DebugLoc DL = TheLoop->getStartLoc(); + + if (I) { + CodeRegion = I->getParent(); + // If there is no debug location attached to the instruction, revert back to + // using the loop's. + if (I->getDebugLoc()) + DL = I->getDebugLoc(); + } + + OptimizationRemarkAnalysis R(PassName, RemarkName, DL, CodeRegion); + R << "loop not vectorized: "; + return R; +} + namespace { // Forward declarations. @@ -221,70 +248,13 @@ class LoopVectorizationLegality; class LoopVectorizationCostModel; class LoopVectorizationRequirements; -// A traits type that is intended to be used in graph algorithms. The graph it -// models starts at the loop header, and traverses the BasicBlocks that are in -// the loop body, but not the loop header. Since the loop header is skipped, -// the back edges are excluded. -struct LoopBodyTraits { - using NodeRef = std::pair<const Loop *, BasicBlock *>; - - // This wraps a const Loop * into the iterator, so we know which edges to - // filter out. - class WrappedSuccIterator - : public iterator_adaptor_base< - WrappedSuccIterator, succ_iterator, - typename std::iterator_traits<succ_iterator>::iterator_category, - NodeRef, std::ptrdiff_t, NodeRef *, NodeRef> { - using BaseT = iterator_adaptor_base< - WrappedSuccIterator, succ_iterator, - typename std::iterator_traits<succ_iterator>::iterator_category, - NodeRef, std::ptrdiff_t, NodeRef *, NodeRef>; - - const Loop *L; - - public: - WrappedSuccIterator(succ_iterator Begin, const Loop *L) - : BaseT(Begin), L(L) {} - - NodeRef operator*() const { return {L, *I}; } - }; - - struct LoopBodyFilter { - bool operator()(NodeRef N) const { - const Loop *L = N.first; - return N.second != L->getHeader() && L->contains(N.second); - } - }; - - using ChildIteratorType = - filter_iterator<WrappedSuccIterator, LoopBodyFilter>; - - static NodeRef getEntryNode(const Loop &G) { return {&G, G.getHeader()}; } - - static ChildIteratorType child_begin(NodeRef Node) { - return make_filter_range(make_range<WrappedSuccIterator>( - {succ_begin(Node.second), Node.first}, - {succ_end(Node.second), Node.first}), - LoopBodyFilter{}) - .begin(); - } - - static ChildIteratorType child_end(NodeRef Node) { - return make_filter_range(make_range<WrappedSuccIterator>( - {succ_begin(Node.second), Node.first}, - {succ_end(Node.second), Node.first}), - LoopBodyFilter{}) - .end(); - } -}; - /// Returns true if the given loop body has a cycle, excluding the loop /// itself. static bool hasCyclesInLoopBody(const Loop &L) { if (!L.empty()) return true; - for (const auto SCC : + for (const auto &SCC : make_range(scc_iterator<Loop, LoopBodyTraits>::begin(L), scc_iterator<Loop, LoopBodyTraits>::end(L))) { if (SCC.size() > 1) { @@ -346,6 +316,41 @@ static GetElementPtrInst *getGEPInstruction(Value *Ptr) { return nullptr; } +/// A helper function that returns the pointer operand of a load or store +/// instruction. +static Value *getPointerOperand(Value *I) { + if (auto *LI = dyn_cast<LoadInst>(I)) + return LI->getPointerOperand(); + if (auto *SI = dyn_cast<StoreInst>(I)) + return SI->getPointerOperand(); + return nullptr; +} + +/// A helper function that returns true if the given type is irregular. 
The +/// type is irregular if its allocated size doesn't equal the store size of an +/// element of the corresponding vector type at the given vectorization factor. +static bool hasIrregularType(Type *Ty, const DataLayout &DL, unsigned VF) { + + // Determine if an array of VF elements of type Ty is "bitcast compatible" + // with a <VF x Ty> vector. + if (VF > 1) { + auto *VectorTy = VectorType::get(Ty, VF); + return VF * DL.getTypeAllocSize(Ty) != DL.getTypeStoreSize(VectorTy); + } + + // If the vectorization factor is one, we just check if an array of type Ty + // requires padding between elements. + return DL.getTypeAllocSizeInBits(Ty) != DL.getTypeSizeInBits(Ty); +} + +/// A helper function that returns the reciprocal of the block probability of +/// predicated blocks. If we return X, we are assuming the predicated block +/// will execute once for every X iterations of the loop header. +/// +/// TODO: We should use actual block probability here, if available. Currently, +/// we always assume predicated blocks have a 50% chance of executing. +static unsigned getReciprocalPredBlockProb() { return 2; } + /// InnerLoopVectorizer vectorizes loops which contain only one basic /// block to a specified vectorization factor (VF). /// This class performs the widening of scalars into vectors, or multiple @@ -366,29 +371,21 @@ public: LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, AssumptionCache *AC, - unsigned VecWidth, unsigned UnrollFactor) + OptimizationRemarkEmitter *ORE, unsigned VecWidth, + unsigned UnrollFactor, LoopVectorizationLegality *LVL, + LoopVectorizationCostModel *CM) : OrigLoop(OrigLoop), PSE(PSE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), - AC(AC), VF(VecWidth), UF(UnrollFactor), + AC(AC), ORE(ORE), VF(VecWidth), UF(UnrollFactor), Builder(PSE.getSE()->getContext()), Induction(nullptr), - OldInduction(nullptr), WidenMap(UnrollFactor), TripCount(nullptr), - VectorTripCount(nullptr), Legal(nullptr), AddedSafetyChecks(false) {} + OldInduction(nullptr), VectorLoopValueMap(UnrollFactor, VecWidth), + TripCount(nullptr), VectorTripCount(nullptr), Legal(LVL), Cost(CM), + AddedSafetyChecks(false) {} // Perform the actual loop widening (vectorization). - // MinimumBitWidths maps scalar integer values to the smallest bitwidth they - // can be validly truncated to. The cost model has assumed this truncation - // will happen when vectorizing. VecValuesToIgnore contains scalar values - // that the cost model has chosen to ignore because they will not be - // vectorized. - void vectorize(LoopVectorizationLegality *L, - const MapVector<Instruction *, uint64_t> &MinimumBitWidths, - SmallPtrSetImpl<const Value *> &VecValuesToIgnore) { - MinBWs = &MinimumBitWidths; - ValuesNotWidened = &VecValuesToIgnore; - Legal = L; + void vectorize() { // Create a new empty loop. Unlink the old loop and connect the new one. createEmptyLoop(); // Widen each instruction in the old loop to a new one in the new loop. - // Use the Legality module to find the induction and reduction variables. vectorizeLoop(); } @@ -400,11 +397,18 @@ public: protected: /// A small list of PHINodes. typedef SmallVector<PHINode *, 4> PhiVector; - /// When we unroll loops we have multiple vector values for each scalar. - /// This data structure holds the unrolled and vectorized values that - /// originated from one scalar instruction. + + /// A type for vectorized values in the new loop. 
Each value from the + /// original loop, when vectorized, is represented by UF vector values in the + /// new unrolled loop, where UF is the unroll factor. typedef SmallVector<Value *, 2> VectorParts; + /// A type for scalarized values in the new loop. Each value from the + /// original loop, when scalarized, is represented by UF x VF scalar values + /// in the new unrolled loop, where UF is the unroll factor and VF is the + /// vectorization factor. + typedef SmallVector<SmallVector<Value *, 4>, 2> ScalarParts; + // When we if-convert we need to create edge masks. We have to cache values // so that we don't end up with exponential recursion/IR. typedef DenseMap<std::pair<BasicBlock *, BasicBlock *>, VectorParts> @@ -434,7 +438,20 @@ protected: /// See PR14725. void fixLCSSAPHIs(); - /// Shrinks vector element sizes based on information in "MinBWs". + /// Iteratively sink the scalarized operands of a predicated instruction into + /// the block that was created for it. + void sinkScalarOperands(Instruction *PredInst); + + /// Predicate conditional instructions that require predication on their + /// respective conditions. + void predicateInstructions(); + + /// Collect the instructions from the original loop that would be trivially + /// dead in the vectorized loop if generated. + void collectTriviallyDeadInstructions(); + + /// Shrinks vector element sizes to the smallest bitwidth they can be legally + /// represented as. void truncateToMinimalBitwidths(); /// A helper function that computes the predicate of the block BB, assuming @@ -451,19 +468,19 @@ protected: /// Vectorize a single PHINode in a block. This method handles the induction /// variable canonicalization. It supports both VF = 1 for unrolled loops and /// arbitrary length vectors. - void widenPHIInstruction(Instruction *PN, VectorParts &Entry, unsigned UF, - unsigned VF, PhiVector *PV); + void widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF, + PhiVector *PV); /// Insert the new loop to the loop hierarchy and pass manager /// and update the analysis passes. void updateAnalysis(); /// This instruction is un-vectorizable. Implement it as a sequence - /// of scalars. If \p IfPredicateStore is true we need to 'hide' each + /// of scalars. If \p IfPredicateInstr is true we need to 'hide' each /// scalarized instruction behind an if block predicated on the control /// dependence of the instruction. virtual void scalarizeInstruction(Instruction *Instr, - bool IfPredicateStore = false); + bool IfPredicateInstr = false); /// Vectorize Load and Store instructions, virtual void vectorizeMemoryInstruction(Instruction *Instr); @@ -477,7 +494,10 @@ protected: /// This function adds (StartIdx, StartIdx + Step, StartIdx + 2*Step, ...) /// to each vector element of Val. The sequence starts at StartIndex. - virtual Value *getStepVector(Value *Val, int StartIdx, Value *Step); + /// \p Opcode is relevant for FP induction variable. + virtual Value *getStepVector(Value *Val, int StartIdx, Value *Step, + Instruction::BinaryOps Opcode = + Instruction::BinaryOpsEnd); /// Compute scalar induction steps. \p ScalarIV is the scalar induction /// variable on which to base the steps, \p Step is the size of the step, and @@ -488,23 +508,39 @@ protected: /// Create a vector induction phi node based on an existing scalar one. This /// currently only works for integer induction variables with a constant - /// step. If \p TruncType is non-null, instead of widening the original IV, - /// we widen a version of the IV truncated to \p TruncType. 
+ /// step. \p EntryVal is the value from the original loop that maps to the + /// vector phi node. If \p EntryVal is a truncate instruction, instead of + /// widening the original IV, we widen a version of the IV truncated to \p + /// EntryVal's type. void createVectorIntInductionPHI(const InductionDescriptor &II, - VectorParts &Entry, IntegerType *TruncType); + Instruction *EntryVal); /// Widen an integer induction variable \p IV. If \p Trunc is provided, the - /// induction variable will first be truncated to the corresponding type. The - /// widened values are placed in \p Entry. - void widenIntInduction(PHINode *IV, VectorParts &Entry, - TruncInst *Trunc = nullptr); - - /// When we go over instructions in the basic block we rely on previous - /// values within the current basic block or on loop invariant values. - /// When we widen (vectorize) values we place them in the map. If the values - /// are not within the map, they have to be loop invariant, so we simply - /// broadcast them into a vector. - VectorParts &getVectorValue(Value *V); + /// induction variable will first be truncated to the corresponding type. + void widenIntInduction(PHINode *IV, TruncInst *Trunc = nullptr); + + /// Returns true if an instruction \p I should be scalarized instead of + /// vectorized for the chosen vectorization factor. + bool shouldScalarizeInstruction(Instruction *I) const; + + /// Returns true if we should generate a scalar version of \p IV. + bool needsScalarInduction(Instruction *IV) const; + + /// Return a constant reference to the VectorParts corresponding to \p V from + /// the original loop. If the value has already been vectorized, the + /// corresponding vector entry in VectorLoopValueMap is returned. If, + /// however, the value has a scalar entry in VectorLoopValueMap, we construct + /// new vector values on-demand by inserting the scalar values into vectors + /// with an insertelement sequence. If the value has been neither vectorized + /// nor scalarized, it must be loop invariant, so we simply broadcast the + /// value into vectors. + const VectorParts &getVectorValue(Value *V); + + /// Return a value in the new loop corresponding to \p V from the original + /// loop at unroll index \p Part and vector index \p Lane. If the value has + /// been vectorized but not scalarized, the necessary extractelement + /// instruction will be generated. + Value *getScalarValue(Value *V, unsigned Part, unsigned Lane); /// Try to vectorize the interleaved access group that \p Instr belongs to. void vectorizeInterleaveGroup(Instruction *Instr); @@ -547,44 +583,112 @@ protected: /// vector of instructions. void addMetadata(ArrayRef<Value *> To, Instruction *From); - /// This is a helper class that holds the vectorizer state. It maps scalar - /// instructions to vector instructions. When the code is 'unrolled' then - /// then a single scalar value is mapped to multiple vector parts. The parts - /// are stored in the VectorPart type. + /// This is a helper class for maintaining vectorization state. It's used for + /// mapping values from the original loop to their corresponding values in + /// the new loop. Two mappings are maintained: one for vectorized values and + /// one for scalarized values. Vectorized values are represented with UF + /// vector values in the new loop, and scalarized values are represented with + /// UF x VF scalar values in the new loop. UF and VF are the unroll and + /// vectorization factors, respectively. 
+ /// + /// Entries can be added to either map with initVector and initScalar, which + /// initialize and return a constant reference to the new entry. If a + /// non-constant reference to a vector entry is required, getVector can be + /// used to retrieve a mutable entry. We currently directly modify the mapped + /// values during "fix-up" operations that occur once the first phase of + /// widening is complete. These operations include type truncation and the + /// second phase of recurrence widening. + /// + /// Otherwise, entries from either map should be accessed using the + /// getVectorValue or getScalarValue functions from InnerLoopVectorizer. + /// getVectorValue and getScalarValue coordinate to generate a vector or + /// scalar value on-demand if one is not yet available. When vectorizing a + /// loop, we visit the definition of an instruction before its uses. When + /// visiting the definition, we either vectorize or scalarize the + /// instruction, creating an entry for it in the corresponding map. (In some + /// cases, such as induction variables, we will create both vector and scalar + /// entries.) Then, as we encounter uses of the definition, we derive values + /// for each scalar or vector use unless such a value is already available. + /// For example, if we scalarize a definition and one of its uses is vector, + /// we build the required vector on-demand with an insertelement sequence + /// when visiting the use. Otherwise, if the use is scalar, we can use the + /// existing scalar definition. struct ValueMap { - /// C'tor. UnrollFactor controls the number of vectors ('parts') that - /// are mapped. - ValueMap(unsigned UnrollFactor) : UF(UnrollFactor) {} - - /// \return True if 'Key' is saved in the Value Map. - bool has(Value *Key) const { return MapStorage.count(Key); } - - /// Initializes a new entry in the map. Sets all of the vector parts to the - /// save value in 'Val'. - /// \return A reference to a vector with splat values. - VectorParts &splat(Value *Key, Value *Val) { - VectorParts &Entry = MapStorage[Key]; - Entry.assign(UF, Val); - return Entry; + + /// Construct an empty map with the given unroll and vectorization factors. + ValueMap(unsigned UnrollFactor, unsigned VecWidth) + : UF(UnrollFactor), VF(VecWidth) { + // The unroll and vectorization factors are only used in asserts builds + // to verify map entries are sized appropriately. + (void)UF; + (void)VF; } - ///\return A reference to the value that is stored at 'Key'. - VectorParts &get(Value *Key) { - VectorParts &Entry = MapStorage[Key]; - if (Entry.empty()) - Entry.resize(UF); - assert(Entry.size() == UF); - return Entry; + /// \return True if the map has a vector entry for \p Key. + bool hasVector(Value *Key) const { return VectorMapStorage.count(Key); } + + /// \return True if the map has a scalar entry for \p Key. + bool hasScalar(Value *Key) const { return ScalarMapStorage.count(Key); } + + /// \brief Map \p Key to the given VectorParts \p Entry, and return a + /// constant reference to the new vector map entry. The given key should + /// not already be in the map, and the given VectorParts should be + /// correctly sized for the current unroll factor. 
+ const VectorParts &initVector(Value *Key, const VectorParts &Entry) { + assert(!hasVector(Key) && "Vector entry already initialized"); + assert(Entry.size() == UF && "VectorParts has wrong dimensions"); + VectorMapStorage[Key] = Entry; + return VectorMapStorage[Key]; } + /// \brief Map \p Key to the given ScalarParts \p Entry, and return a + /// constant reference to the new scalar map entry. The given key should + /// not already be in the map, and the given ScalarParts should be + /// correctly sized for the current unroll and vectorization factors. + const ScalarParts &initScalar(Value *Key, const ScalarParts &Entry) { + assert(!hasScalar(Key) && "Scalar entry already initialized"); + assert(Entry.size() == UF && + all_of(make_range(Entry.begin(), Entry.end()), + [&](const SmallVectorImpl<Value *> &Values) -> bool { + return Values.size() == VF; + }) && + "ScalarParts has wrong dimensions"); + ScalarMapStorage[Key] = Entry; + return ScalarMapStorage[Key]; + } + + /// \return A reference to the vector map entry corresponding to \p Key. + /// The key should already be in the map. This function should only be used + /// when it's necessary to update values that have already been vectorized. + /// This is the case for "fix-up" operations including type truncation and + /// the second phase of recurrence vectorization. If a non-const reference + /// isn't required, getVectorValue should be used instead. + VectorParts &getVector(Value *Key) { + assert(hasVector(Key) && "Vector entry not initialized"); + return VectorMapStorage.find(Key)->second; + } + + /// Retrieve an entry from the vector or scalar maps. The preferred way to + /// access an existing mapped entry is with getVectorValue or + /// getScalarValue from InnerLoopVectorizer. Until those functions can be + /// moved inside ValueMap, we have to declare them as friends. + friend const VectorParts &InnerLoopVectorizer::getVectorValue(Value *V); + friend Value *InnerLoopVectorizer::getScalarValue(Value *V, unsigned Part, + unsigned Lane); + private: - /// The unroll factor. Each entry in the map stores this number of vector - /// elements. + /// The unroll factor. Each entry in the vector map contains UF vector + /// values. unsigned UF; - /// Map storage. We use std::map and not DenseMap because insertions to a - /// dense map invalidates its iterators. - std::map<Value *, VectorParts> MapStorage; + /// The vectorization factor. Each entry in the scalar map contains UF x VF + /// scalar values. + unsigned VF; + + /// The vector and scalar map storage. We use std::map and not DenseMap + /// because insertions to DenseMap invalidate its iterators. + std::map<Value *, VectorParts> VectorMapStorage; + std::map<Value *, ScalarParts> ScalarMapStorage; }; /// The original loop. @@ -605,6 +709,8 @@ protected: const TargetTransformInfo *TTI; /// Assumption Cache. AssumptionCache *AC; + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; /// \brief LoopVersioning. It's only set up (non-null) if memchecks were /// used. @@ -646,41 +752,42 @@ protected: PHINode *Induction; /// The induction variable of the old basic block. PHINode *OldInduction; - /// Maps scalars to widened vectors. - ValueMap WidenMap; - - /// A map of induction variables from the original loop to their - /// corresponding VF * UF scalarized values in the vectorized loop. The - /// purpose of ScalarIVMap is similar to that of WidenMap. 
Whereas WidenMap - /// maps original loop values to their vector versions in the new loop, - /// ScalarIVMap maps induction variables from the original loop that are not - /// vectorized to their scalar equivalents in the vector loop. Maintaining a - /// separate map for scalarized induction variables allows us to avoid - /// unnecessary scalar-to-vector-to-scalar conversions. - DenseMap<Value *, SmallVector<Value *, 8>> ScalarIVMap; + + /// Maps values from the original loop to their corresponding values in the + /// vectorized loop. A key value can map to either vector values, scalar + /// values or both kinds of values, depending on whether the key was + /// vectorized and scalarized. + ValueMap VectorLoopValueMap; /// Store instructions that should be predicated, as a pair /// <StoreInst, Predicate> - SmallVector<std::pair<StoreInst *, Value *>, 4> PredicatedStores; + SmallVector<std::pair<Instruction *, Value *>, 4> PredicatedInstructions; EdgeMaskCache MaskCache; /// Trip count of the original loop. Value *TripCount; /// Trip count of the widened loop (TripCount - TripCount % (VF*UF)) Value *VectorTripCount; - /// Map of scalar integer values to the smallest bitwidth they can be legally - /// represented as. The vector equivalents of these values should be truncated - /// to this type. - const MapVector<Instruction *, uint64_t> *MinBWs; - - /// A set of values that should not be widened. This is taken from - /// VecValuesToIgnore in the cost model. - SmallPtrSetImpl<const Value *> *ValuesNotWidened; - + /// The legality analysis. LoopVectorizationLegality *Legal; + /// The profitablity analysis. + LoopVectorizationCostModel *Cost; + // Record whether runtime checks are added. bool AddedSafetyChecks; + + // Holds instructions from the original loop whose counterparts in the + // vectorized loop would be trivially dead if generated. For example, + // original induction update instructions can become dead because we + // separately emit induction "steps" when generating code for the new loop. + // Similarly, we create a new latch condition when setting up the structure + // of the new loop, so the old one can become dead. + SmallPtrSet<Instruction *, 4> DeadInstructions; + + // Holds the end values for each induction variable. We save the end values + // so we can later fix-up the external users of the induction variables. + DenseMap<PHINode *, Value *> IVEndValues; }; class InnerLoopUnroller : public InnerLoopVectorizer { @@ -689,16 +796,20 @@ public: LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, AssumptionCache *AC, - unsigned UnrollFactor) - : InnerLoopVectorizer(OrigLoop, PSE, LI, DT, TLI, TTI, AC, 1, - UnrollFactor) {} + OptimizationRemarkEmitter *ORE, unsigned UnrollFactor, + LoopVectorizationLegality *LVL, + LoopVectorizationCostModel *CM) + : InnerLoopVectorizer(OrigLoop, PSE, LI, DT, TLI, TTI, AC, ORE, 1, + UnrollFactor, LVL, CM) {} private: void scalarizeInstruction(Instruction *Instr, - bool IfPredicateStore = false) override; + bool IfPredicateInstr = false) override; void vectorizeMemoryInstruction(Instruction *Instr) override; Value *getBroadcastInstrs(Value *V) override; - Value *getStepVector(Value *Val, int StartIdx, Value *Step) override; + Value *getStepVector(Value *Val, int StartIdx, Value *Step, + Instruction::BinaryOps Opcode = + Instruction::BinaryOpsEnd) override; Value *reverseVector(Value *Vec) override; }; @@ -1149,12 +1260,13 @@ public: FK_Enabled = 1, ///< Forcing enabled. 
}; - LoopVectorizeHints(const Loop *L, bool DisableInterleaving) + LoopVectorizeHints(const Loop *L, bool DisableInterleaving, + OptimizationRemarkEmitter &ORE) : Width("vectorize.width", VectorizerParams::VectorizationFactor, HK_WIDTH), Interleave("interleave.count", DisableInterleaving, HK_UNROLL), Force("vectorize.enable", FK_Undefined, HK_FORCE), - PotentiallyUnsafe(false), TheLoop(L) { + PotentiallyUnsafe(false), TheLoop(L), ORE(ORE) { // Populate values with existing loop metadata. getHintsFromMetadata(); @@ -1176,17 +1288,13 @@ public: bool allowVectorization(Function *F, Loop *L, bool AlwaysVectorize) const { if (getForce() == LoopVectorizeHints::FK_Disabled) { DEBUG(dbgs() << "LV: Not vectorizing: #pragma vectorize disable.\n"); - emitOptimizationRemarkAnalysis(F->getContext(), - vectorizeAnalysisPassName(), *F, - L->getStartLoc(), emitRemark()); + emitRemarkWithHints(); return false; } if (!AlwaysVectorize && getForce() != LoopVectorizeHints::FK_Enabled) { DEBUG(dbgs() << "LV: Not vectorizing: No #pragma vectorize enable.\n"); - emitOptimizationRemarkAnalysis(F->getContext(), - vectorizeAnalysisPassName(), *F, - L->getStartLoc(), emitRemark()); + emitRemarkWithHints(); return false; } @@ -1197,11 +1305,12 @@ public: // FIXME: Add interleave.disable metadata. This will allow // vectorize.disable to be used without disabling the pass and errors // to differentiate between disabled vectorization and a width of 1. - emitOptimizationRemarkAnalysis( - F->getContext(), vectorizeAnalysisPassName(), *F, L->getStartLoc(), - "loop not vectorized: vectorization and interleaving are explicitly " - "disabled, or vectorize width and interleave count are both set to " - "1"); + ORE.emit(OptimizationRemarkAnalysis(vectorizeAnalysisPassName(), + "AllDisabled", L->getStartLoc(), + L->getHeader()) + << "loop not vectorized: vectorization and interleaving are " + "explicitly disabled, or vectorize width and interleave " + "count are both set to 1"); return false; } @@ -1209,23 +1318,27 @@ public: } /// Dumps all the hint information. - std::string emitRemark() const { - VectorizationReport R; + void emitRemarkWithHints() const { + using namespace ore; if (Force.Value == LoopVectorizeHints::FK_Disabled) - R << "vectorization is explicitly disabled"; + ORE.emit(OptimizationRemarkMissed(LV_NAME, "MissedExplicitlyDisabled", + TheLoop->getStartLoc(), + TheLoop->getHeader()) + << "loop not vectorized: vectorization is explicitly disabled"); else { - R << "use -Rpass-analysis=loop-vectorize for more info"; + OptimizationRemarkMissed R(LV_NAME, "MissedDetails", + TheLoop->getStartLoc(), TheLoop->getHeader()); + R << "loop not vectorized"; if (Force.Value == LoopVectorizeHints::FK_Enabled) { - R << " (Force=true"; + R << " (Force=" << NV("Force", true); if (Width.Value != 0) - R << ", Vector Width=" << Width.Value; + R << ", Vector Width=" << NV("VectorWidth", Width.Value); if (Interleave.Value != 0) - R << ", Interleave Count=" << Interleave.Value; + R << ", Interleave Count=" << NV("InterleaveCount", Interleave.Value); R << ")"; } + ORE.emit(R); } - - return R.str(); } unsigned getWidth() const { return Width.Value; } @@ -1241,7 +1354,7 @@ public: return LV_NAME; if (getForce() == LoopVectorizeHints::FK_Undefined && getWidth() == 0) return LV_NAME; - return DiagnosticInfoOptimizationRemarkAnalysis::AlwaysPrint; + return OptimizationRemarkAnalysis::AlwaysPrint; } bool allowReordering() const { @@ -1379,19 +1492,23 @@ private: /// The loop these hints belong to. 
const Loop *TheLoop; + + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter &ORE; }; -static void emitAnalysisDiag(const Function *TheFunction, const Loop *TheLoop, +static void emitAnalysisDiag(const Loop *TheLoop, const LoopVectorizeHints &Hints, + OptimizationRemarkEmitter &ORE, const LoopAccessReport &Message) { const char *Name = Hints.vectorizeAnalysisPassName(); - LoopAccessReport::emitAnalysis(Message, TheFunction, TheLoop, Name); + LoopAccessReport::emitAnalysis(Message, TheLoop, Name, ORE); } static void emitMissedWarning(Function *F, Loop *L, - const LoopVectorizeHints &LH) { - emitOptimizationRemarkMissed(F->getContext(), LV_NAME, *F, L->getStartLoc(), - LH.emitRemark()); + const LoopVectorizeHints &LH, + OptimizationRemarkEmitter *ORE) { + LH.emitRemarkWithHints(); if (LH.getForce() == LoopVectorizeHints::FK_Enabled) { if (LH.getWidth() != 1) @@ -1425,12 +1542,12 @@ public: TargetLibraryInfo *TLI, AliasAnalysis *AA, Function *F, const TargetTransformInfo *TTI, std::function<const LoopAccessInfo &(Loop &)> *GetLAA, LoopInfo *LI, - LoopVectorizationRequirements *R, LoopVectorizeHints *H) - : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TheFunction(F), - TTI(TTI), DT(DT), GetLAA(GetLAA), LAI(nullptr), - InterleaveInfo(PSE, L, DT, LI), Induction(nullptr), - WidestIndTy(nullptr), HasFunNoNaNAttr(false), Requirements(R), - Hints(H) {} + OptimizationRemarkEmitter *ORE, LoopVectorizationRequirements *R, + LoopVectorizeHints *H) + : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TTI(TTI), DT(DT), + GetLAA(GetLAA), LAI(nullptr), ORE(ORE), InterleaveInfo(PSE, L, DT, LI), + Induction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), + Requirements(R), Hints(H) {} /// ReductionList contains the reduction descriptors for all /// of the reductions that were found in the loop. @@ -1490,9 +1607,12 @@ public: /// Returns true if the value V is uniform within the loop. bool isUniform(Value *V); - /// Returns true if this instruction will remain scalar after vectorization. + /// Returns true if \p I is known to be uniform after vectorization. bool isUniformAfterVectorization(Instruction *I) { return Uniforms.count(I); } + /// Returns true if \p I is known to be scalar after vectorization. + bool isScalarAfterVectorization(Instruction *I) { return Scalars.count(I); } + /// Returns the information that we collected about runtime memory check. const RuntimePointerChecking *getRuntimePointerChecking() const { return LAI->getRuntimePointerChecking(); @@ -1545,6 +1665,17 @@ public: bool isLegalMaskedGather(Type *DataType) { return TTI->isLegalMaskedGather(DataType); } + /// Returns true if the target machine can represent \p V as a masked gather + /// or scatter operation. + bool isLegalGatherOrScatter(Value *V) { + auto *LI = dyn_cast<LoadInst>(V); + auto *SI = dyn_cast<StoreInst>(V); + if (!LI && !SI) + return false; + auto *Ptr = getPointerOperand(V); + auto *Ty = cast<PointerType>(Ptr->getType())->getElementType(); + return (LI && isLegalMaskedGather(Ty)) || (SI && isLegalMaskedScatter(Ty)); + } /// Returns true if vector representation of the instruction \p I /// requires mask. @@ -1553,6 +1684,21 @@ public: unsigned getNumLoads() const { return LAI->getNumLoads(); } unsigned getNumPredStores() const { return NumPredStores; } + /// Returns true if \p I is an instruction that will be scalarized with + /// predication. Such instructions include conditional stores and + /// instructions that may divide by zero. 
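Illustrative sketch, not part of the patch: a source-level loop of the kind that doc comment describes. The function and variable names are hypothetical. Assuming the loop is if-converted and vectorized, the division is only well-defined on lanes where the guard holds, and the store must not happen on the other lanes, so both would be scalarized and predicated rather than executed speculatively.

// Hypothetical input loop; a, b and n are illustrative names only.
void scaleGuarded(int *a, const int *b, int n) {
  for (int i = 0; i < n; ++i)
    if (b[i] != 0)          // becomes the per-lane predicate
      a[i] = a[i] / b[i];   // sdiv and store are only safe on true lanes
}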
+ bool isScalarWithPredication(Instruction *I); + + /// Returns true if \p I is a memory instruction that has a consecutive or + /// consecutive-like pointer operand. Consecutive-like pointers are pointers + /// that are treated like consecutive pointers during vectorization. The + /// pointer operands of interleaved accesses are an example. + bool hasConsecutiveLikePtrOperand(Instruction *I); + + /// Returns true if \p I is a memory instruction that must be scalarized + /// during vectorization. + bool memoryInstructionMustBeScalarized(Instruction *I, unsigned VF = 1); + private: /// Check if a single basic block loop is vectorizable. /// At this point we know that this is a loop with a constant trip count @@ -1569,9 +1715,24 @@ private: /// transformation. bool canVectorizeWithIfConvert(); - /// Collect the variables that need to stay uniform after vectorization. + /// Collect the instructions that are uniform after vectorization. An + /// instruction is uniform if we represent it with a single scalar value in + /// the vectorized loop corresponding to each vector iteration. Examples of + /// uniform instructions include pointer operands of consecutive or + /// interleaved memory accesses. Note that although uniformity implies an + /// instruction will be scalar, the reverse is not true. In general, a + /// scalarized instruction will be represented by VF scalar values in the + /// vectorized loop, each corresponding to an iteration of the original + /// scalar loop. void collectLoopUniforms(); + /// Collect the instructions that are scalar after vectorization. An + /// instruction is scalar if it is known to be uniform or will be scalarized + /// during vectorization. Non-uniform scalarized instructions will be + /// represented by VF values in the vectorized loop, each corresponding to an + /// iteration of the original scalar loop. + void collectLoopScalars(); + /// Return true if all of the instructions in the block can be speculatively /// executed. \p SafePtrs is a list of addresses that are known to be legal /// and we know that we can read from them without segfault. @@ -1588,7 +1749,19 @@ private: /// VectorizationReport because the << operator of VectorizationReport returns /// LoopAccessReport. void emitAnalysis(const LoopAccessReport &Message) const { - emitAnalysisDiag(TheFunction, TheLoop, *Hints, Message); + emitAnalysisDiag(TheLoop, *Hints, *ORE, Message); + } + + /// Create an analysis remark that explains why vectorization failed + /// + /// \p RemarkName is the identifier for the remark. If \p I is passed it is + /// an instruction that prevents vectorization. Otherwise the loop is used + /// for the location of the remark. \return the remark object that can be + /// streamed to. + OptimizationRemarkAnalysis + createMissedAnalysis(StringRef RemarkName, Instruction *I = nullptr) const { + return ::createMissedAnalysis(Hints->vectorizeAnalysisPassName(), + RemarkName, TheLoop, I); } /// \brief If an access has a symbolic strides, this maps the pointer value to @@ -1613,8 +1786,6 @@ private: PredicatedScalarEvolution &PSE; /// Target Library Info. TargetLibraryInfo *TLI; - /// Parent function - Function *TheFunction; /// Target Transform Info const TargetTransformInfo *TTI; /// Dominator Tree. @@ -1624,6 +1795,8 @@ private: // And the loop-accesses info corresponding to this loop. This pointer is // null until canVectorizeMemory sets it up. const LoopAccessInfo *LAI; + /// Interface to emit optimization remarks. 
+ OptimizationRemarkEmitter *ORE; /// The interleave access information contains groups of interleaved accesses /// with the same stride and close to each other. @@ -1648,10 +1821,13 @@ private: /// Allowed outside users. This holds the induction and reduction /// vars which can be accessed from outside the loop. SmallPtrSet<Value *, 4> AllowedExit; - /// This set holds the variables which are known to be uniform after - /// vectorization. + + /// Holds the instructions known to be uniform after vectorization. SmallPtrSet<Instruction *, 4> Uniforms; + /// Holds the instructions known to be scalar after vectorization. + SmallPtrSet<Instruction *, 4> Scalars; + /// Can we assume the absence of NaNs. bool HasFunNoNaNAttr; @@ -1679,10 +1855,11 @@ public: LoopInfo *LI, LoopVectorizationLegality *Legal, const TargetTransformInfo &TTI, const TargetLibraryInfo *TLI, DemandedBits *DB, - AssumptionCache *AC, const Function *F, + AssumptionCache *AC, + OptimizationRemarkEmitter *ORE, const Function *F, const LoopVectorizeHints *Hints) : TheLoop(L), PSE(PSE), LI(LI), Legal(Legal), TTI(TTI), TLI(TLI), DB(DB), - AC(AC), TheFunction(F), Hints(Hints) {} + AC(AC), ORE(ORE), TheFunction(F), Hints(Hints) {} /// Information about vectorization costs struct VectorizationFactor { @@ -1707,13 +1884,6 @@ public: unsigned selectInterleaveCount(bool OptForSize, unsigned VF, unsigned LoopCost); - /// \return The most profitable unroll factor. - /// This method finds the best unroll-factor based on register pressure and - /// other parameters. VF and LoopCost are the selected vectorization factor - /// and the cost of the selected VF. - unsigned computeInterleaveCount(bool OptForSize, unsigned VF, - unsigned LoopCost); - /// \brief A struct that represents some properties of the register usage /// of a loop. struct RegisterUsage { @@ -1732,6 +1902,29 @@ public: /// Collect values we want to ignore in the cost model. void collectValuesToIgnore(); + /// \returns The smallest bitwidth each instruction can be represented with. + /// The vector equivalents of these instructions should be truncated to this + /// type. + const MapVector<Instruction *, uint64_t> &getMinimalBitwidths() const { + return MinBWs; + } + + /// \returns True if it is more profitable to scalarize instruction \p I for + /// vectorization factor \p VF. + bool isProfitableToScalarize(Instruction *I, unsigned VF) const { + auto Scalars = InstsToScalarize.find(VF); + assert(Scalars != InstsToScalarize.end() && + "VF not yet analyzed for scalarization profitability"); + return Scalars->second.count(I); + } + + /// \returns True if instruction \p I can be truncated to a smaller bitwidth + /// for vectorization factor \p VF. + bool canTruncateToMinimalBitwidth(Instruction *I, unsigned VF) const { + return VF > 1 && MinBWs.count(I) && !isProfitableToScalarize(I, VF) && + !Legal->isScalarAfterVectorization(I); + } + private: /// The vectorization cost is a combination of the cost itself and a boolean /// indicating whether any of the contributing operations will actually @@ -1760,20 +1953,44 @@ private: /// as a vector operation. bool isConsecutiveLoadOrStore(Instruction *I); - /// Report an analysis message to assist the user in diagnosing loops that are - /// not vectorized. These are handled as LoopAccessReport rather than - /// VectorizationReport because the << operator of VectorizationReport returns - /// LoopAccessReport. 
- void emitAnalysis(const LoopAccessReport &Message) const { - emitAnalysisDiag(TheFunction, TheLoop, *Hints, Message); + /// Create an analysis remark that explains why vectorization failed + /// + /// \p RemarkName is the identifier for the remark. \return the remark object + /// that can be streamed to. + OptimizationRemarkAnalysis createMissedAnalysis(StringRef RemarkName) { + return ::createMissedAnalysis(Hints->vectorizeAnalysisPassName(), + RemarkName, TheLoop); } -public: /// Map of scalar integer values to the smallest bitwidth they can be legally /// represented as. The vector equivalents of these values should be truncated /// to this type. MapVector<Instruction *, uint64_t> MinBWs; + /// A type representing the costs for instructions if they were to be + /// scalarized rather than vectorized. The entries are Instruction-Cost + /// pairs. + typedef DenseMap<Instruction *, unsigned> ScalarCostsTy; + + /// A map holding scalar costs for different vectorization factors. The + /// presence of a cost for an instruction in the mapping indicates that the + /// instruction will be scalarized when vectorizing with the associated + /// vectorization factor. The entries are VF-ScalarCostTy pairs. + DenseMap<unsigned, ScalarCostsTy> InstsToScalarize; + + /// Returns the expected difference in cost from scalarizing the expression + /// feeding a predicated instruction \p PredInst. The instructions to + /// scalarize and their scalar costs are collected in \p ScalarCosts. A + /// non-negative return value implies the expression will be scalarized. + /// Currently, only single-use chains are considered for scalarization. + int computePredInstDiscount(Instruction *PredInst, ScalarCostsTy &ScalarCosts, + unsigned VF); + + /// Collects the instructions to scalarize for each predicated instruction in + /// the loop. + void collectInstsToScalarize(unsigned VF); + +public: /// The loop that we evaluate. Loop *TheLoop; /// Predicated scalar evolution analysis. @@ -1790,6 +2007,9 @@ public: DemandedBits *DB; /// Assumption cache. AssumptionCache *AC; + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; + const Function *TheFunction; /// Loop Vectorize Hint. const LoopVectorizeHints *Hints; @@ -1813,8 +2033,8 @@ public: /// followed by a non-expert user. class LoopVectorizationRequirements { public: - LoopVectorizationRequirements() - : NumRuntimePointerChecks(0), UnsafeAlgebraInst(nullptr) {} + LoopVectorizationRequirements(OptimizationRemarkEmitter &ORE) + : NumRuntimePointerChecks(0), UnsafeAlgebraInst(nullptr), ORE(ORE) {} void addUnsafeAlgebraInst(Instruction *I) { // First unsafe algebra instruction. 
@@ -1825,13 +2045,15 @@ public: void addRuntimePointerChecks(unsigned Num) { NumRuntimePointerChecks = Num; } bool doesNotMeet(Function *F, Loop *L, const LoopVectorizeHints &Hints) { - const char *Name = Hints.vectorizeAnalysisPassName(); + const char *PassName = Hints.vectorizeAnalysisPassName(); bool Failed = false; if (UnsafeAlgebraInst && !Hints.allowReordering()) { - emitOptimizationRemarkAnalysisFPCommute( - F->getContext(), Name, *F, UnsafeAlgebraInst->getDebugLoc(), - VectorizationReport() << "cannot prove it is safe to reorder " - "floating-point operations"); + ORE.emit( + OptimizationRemarkAnalysisFPCommute(PassName, "CantReorderFPOps", + UnsafeAlgebraInst->getDebugLoc(), + UnsafeAlgebraInst->getParent()) + << "loop not vectorized: cannot prove it is safe to reorder " + "floating-point operations"); Failed = true; } @@ -1842,10 +2064,11 @@ public: NumRuntimePointerChecks > VectorizerParams::RuntimeMemoryCheckThreshold; if ((ThresholdReached && !Hints.allowReordering()) || PragmaThresholdReached) { - emitOptimizationRemarkAnalysisAliasing( - F->getContext(), Name, *F, L->getStartLoc(), - VectorizationReport() - << "cannot prove it is safe to reorder memory operations"); + ORE.emit(OptimizationRemarkAnalysisAliasing(PassName, "CantReorderMemOps", + L->getStartLoc(), + L->getHeader()) + << "loop not vectorized: cannot prove it is safe to reorder " + "memory operations"); DEBUG(dbgs() << "LV: Too many memory checks needed.\n"); Failed = true; } @@ -1856,6 +2079,9 @@ public: private: unsigned NumRuntimePointerChecks; Instruction *UnsafeAlgebraInst; + + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter &ORE; }; static void addAcyclicInnerLoop(Loop &L, SmallVectorImpl<Loop *> &V) { @@ -1897,12 +2123,13 @@ struct LoopVectorize : public FunctionPass { auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); auto *DB = &getAnalysis<DemandedBitsWrapperPass>().getDemandedBits(); + auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; return Impl.runImpl(F, *SE, *LI, *TTI, *DT, *BFI, TLI, *DB, *AA, *AC, - GetLAA); + GetLAA, *ORE); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -1917,6 +2144,7 @@ struct LoopVectorize : public FunctionPass { AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<LoopAccessLegacyAnalysis>(); AU.addRequired<DemandedBitsWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<BasicAAWrapperPass>(); @@ -1949,7 +2177,7 @@ Value *InnerLoopVectorizer::getBroadcastInstrs(Value *V) { } void InnerLoopVectorizer::createVectorIntInductionPHI( - const InductionDescriptor &II, VectorParts &Entry, IntegerType *TruncType) { + const InductionDescriptor &II, Instruction *EntryVal) { Value *Start = II.getStartValue(); ConstantInt *Step = II.getConstIntStepValue(); assert(Step && "Can not widen an IV with a non-constant step"); @@ -1957,7 +2185,8 @@ void InnerLoopVectorizer::createVectorIntInductionPHI( // Construct the initial value of the vector IV in the vector loop preheader auto CurrIP = Builder.saveIP(); Builder.SetInsertPoint(LoopVectorPreHeader->getTerminator()); - if (TruncType) { + if (isa<TruncInst>(EntryVal)) { + auto *TruncType = cast<IntegerType>(EntryVal->getType()); Step = 
ConstantInt::getSigned(TruncType, Step->getSExtValue()); Start = Builder.CreateCast(Instruction::Trunc, Start, TruncType); } @@ -1972,18 +2201,45 @@ void InnerLoopVectorizer::createVectorIntInductionPHI( // factor. The last of those goes into the PHI. PHINode *VecInd = PHINode::Create(SteppedStart->getType(), 2, "vec.ind", &*LoopVectorBody->getFirstInsertionPt()); - Value *LastInduction = VecInd; + Instruction *LastInduction = VecInd; + VectorParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { Entry[Part] = LastInduction; - LastInduction = Builder.CreateAdd(LastInduction, SplatVF, "step.add"); + LastInduction = cast<Instruction>( + Builder.CreateAdd(LastInduction, SplatVF, "step.add")); } + VectorLoopValueMap.initVector(EntryVal, Entry); + if (isa<TruncInst>(EntryVal)) + addMetadata(Entry, EntryVal); + + // Move the last step to the end of the latch block. This ensures consistent + // placement of all induction updates. + auto *LoopVectorLatch = LI->getLoopFor(LoopVectorBody)->getLoopLatch(); + auto *Br = cast<BranchInst>(LoopVectorLatch->getTerminator()); + auto *ICmp = cast<Instruction>(Br->getCondition()); + LastInduction->moveBefore(ICmp); + LastInduction->setName("vec.ind.next"); VecInd->addIncoming(SteppedStart, LoopVectorPreHeader); - VecInd->addIncoming(LastInduction, LoopVectorBody); + VecInd->addIncoming(LastInduction, LoopVectorLatch); } -void InnerLoopVectorizer::widenIntInduction(PHINode *IV, VectorParts &Entry, - TruncInst *Trunc) { +bool InnerLoopVectorizer::shouldScalarizeInstruction(Instruction *I) const { + return Legal->isScalarAfterVectorization(I) || + Cost->isProfitableToScalarize(I, VF); +} + +bool InnerLoopVectorizer::needsScalarInduction(Instruction *IV) const { + if (shouldScalarizeInstruction(IV)) + return true; + auto isScalarInst = [&](User *U) -> bool { + auto *I = cast<Instruction>(U); + return (OrigLoop->contains(I) && shouldScalarizeInstruction(I)); + }; + return any_of(IV->users(), isScalarInst); +} + +void InnerLoopVectorizer::widenIntInduction(PHINode *IV, TruncInst *Trunc) { auto II = Legal->getInductionVars()->find(IV); assert(II != Legal->getInductionVars()->end() && "IV is not an induction"); @@ -1991,12 +2247,25 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, VectorParts &Entry, auto ID = II->second; assert(IV->getType() == ID.getStartValue()->getType() && "Types must match"); - // If a truncate instruction was provided, get the smaller type. - auto *TruncType = Trunc ? cast<IntegerType>(Trunc->getType()) : nullptr; + // The scalar value to broadcast. This will be derived from the canonical + // induction variable. + Value *ScalarIV = nullptr; // The step of the induction. Value *Step = nullptr; + // The value from the original loop to which we are mapping the new induction + // variable. + Instruction *EntryVal = Trunc ? cast<Instruction>(Trunc) : IV; + + // True if we have vectorized the induction variable. + auto VectorizedIV = false; + + // Determine if we want a scalar version of the induction variable. This is + // true if the induction variable itself is not widened, or if it has at + // least one user in the loop that is not widened. + auto NeedsScalarIV = VF > 1 && needsScalarInduction(EntryVal); + // If the induction variable has a constant integer step value, go ahead and // get it now. if (ID.getConstIntStepValue()) @@ -2006,40 +2275,50 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, VectorParts &Entry, // create the phi node, we will splat the scalar induction variable in each // loop iteration. 
if (VF > 1 && IV->getType() == Induction->getType() && Step && - !ValuesNotWidened->count(IV)) - return createVectorIntInductionPHI(ID, Entry, TruncType); - - // The scalar value to broadcast. This will be derived from the canonical - // induction variable. - Value *ScalarIV = nullptr; - - // Define the scalar induction variable and step values. If we were given a - // truncation type, truncate the canonical induction variable and constant - // step. Otherwise, derive these values from the induction descriptor. - if (TruncType) { - assert(Step && "Truncation requires constant integer step"); - auto StepInt = cast<ConstantInt>(Step)->getSExtValue(); - ScalarIV = Builder.CreateCast(Instruction::Trunc, Induction, TruncType); - Step = ConstantInt::getSigned(TruncType, StepInt); - } else { - ScalarIV = Induction; - auto &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); - if (IV != OldInduction) { - ScalarIV = Builder.CreateSExtOrTrunc(ScalarIV, IV->getType()); - ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL); - ScalarIV->setName("offset.idx"); - } - if (!Step) { - SCEVExpander Exp(*PSE.getSE(), DL, "induction"); - Step = Exp.expandCodeFor(ID.getStep(), ID.getStep()->getType(), - &*Builder.GetInsertPoint()); + !shouldScalarizeInstruction(EntryVal)) { + createVectorIntInductionPHI(ID, EntryVal); + VectorizedIV = true; + } + + // If we haven't yet vectorized the induction variable, or if we will create + // a scalar one, we need to define the scalar induction variable and step + // values. If we were given a truncation type, truncate the canonical + // induction variable and constant step. Otherwise, derive these values from + // the induction descriptor. + if (!VectorizedIV || NeedsScalarIV) { + if (Trunc) { + auto *TruncType = cast<IntegerType>(Trunc->getType()); + assert(Step && "Truncation requires constant integer step"); + auto StepInt = cast<ConstantInt>(Step)->getSExtValue(); + ScalarIV = Builder.CreateCast(Instruction::Trunc, Induction, TruncType); + Step = ConstantInt::getSigned(TruncType, StepInt); + } else { + ScalarIV = Induction; + auto &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); + if (IV != OldInduction) { + ScalarIV = Builder.CreateSExtOrTrunc(ScalarIV, IV->getType()); + ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL); + ScalarIV->setName("offset.idx"); + } + if (!Step) { + SCEVExpander Exp(*PSE.getSE(), DL, "induction"); + Step = Exp.expandCodeFor(ID.getStep(), ID.getStep()->getType(), + &*Builder.GetInsertPoint()); + } } } - // Splat the scalar induction variable, and build the necessary step vectors. - Value *Broadcasted = getBroadcastInstrs(ScalarIV); - for (unsigned Part = 0; Part < UF; ++Part) - Entry[Part] = getStepVector(Broadcasted, VF * Part, Step); + // If we haven't yet vectorized the induction variable, splat the scalar + // induction variable, and build the necessary step vectors. + if (!VectorizedIV) { + Value *Broadcasted = getBroadcastInstrs(ScalarIV); + VectorParts Entry(UF); + for (unsigned Part = 0; Part < UF; ++Part) + Entry[Part] = getStepVector(Broadcasted, VF * Part, Step); + VectorLoopValueMap.initVector(EntryVal, Entry); + if (Trunc) + addMetadata(Entry, Trunc); + } // If an induction variable is only used for counting loop iterations or // calculating addresses, it doesn't need to be widened. 
Create scalar steps @@ -2047,38 +2326,64 @@ void InnerLoopVectorizer::widenIntInduction(PHINode *IV, VectorParts &Entry, // addition of the scalar steps will not increase the number of instructions // in the loop in the common case prior to InstCombine. We will be trading // one vector extract for each scalar step. - if (VF > 1 && ValuesNotWidened->count(IV)) { - auto *EntryVal = Trunc ? cast<Value>(Trunc) : IV; + if (NeedsScalarIV) buildScalarSteps(ScalarIV, Step, EntryVal); - } } -Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, - Value *Step) { +Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, + Instruction::BinaryOps BinOp) { + // Create and check the types. assert(Val->getType()->isVectorTy() && "Must be a vector"); - assert(Val->getType()->getScalarType()->isIntegerTy() && - "Elem must be an integer"); - assert(Step->getType() == Val->getType()->getScalarType() && - "Step has wrong type"); - // Create the types. - Type *ITy = Val->getType()->getScalarType(); - VectorType *Ty = cast<VectorType>(Val->getType()); - int VLen = Ty->getNumElements(); + int VLen = Val->getType()->getVectorNumElements(); + + Type *STy = Val->getType()->getScalarType(); + assert((STy->isIntegerTy() || STy->isFloatingPointTy()) && + "Induction Step must be an integer or FP"); + assert(Step->getType() == STy && "Step has wrong type"); + SmallVector<Constant *, 8> Indices; + if (STy->isIntegerTy()) { + // Create a vector of consecutive numbers from zero to VF. + for (int i = 0; i < VLen; ++i) + Indices.push_back(ConstantInt::get(STy, StartIdx + i)); + + // Add the consecutive indices to the vector value. + Constant *Cv = ConstantVector::get(Indices); + assert(Cv->getType() == Val->getType() && "Invalid consecutive vec"); + Step = Builder.CreateVectorSplat(VLen, Step); + assert(Step->getType() == Val->getType() && "Invalid step vec"); + // FIXME: The newly created binary instructions should contain nsw/nuw flags, + // which can be found from the original scalar operations. + Step = Builder.CreateMul(Cv, Step); + return Builder.CreateAdd(Val, Step, "induction"); + } + + // Floating point induction. + assert((BinOp == Instruction::FAdd || BinOp == Instruction::FSub) && + "Binary Opcode should be specified for FP induction"); // Create a vector of consecutive numbers from zero to VF. for (int i = 0; i < VLen; ++i) - Indices.push_back(ConstantInt::get(ITy, StartIdx + i)); + Indices.push_back(ConstantFP::get(STy, (double)(StartIdx + i))); // Add the consecutive indices to the vector value. Constant *Cv = ConstantVector::get(Indices); - assert(Cv->getType() == Val->getType() && "Invalid consecutive vec"); + Step = Builder.CreateVectorSplat(VLen, Step); - assert(Step->getType() == Val->getType() && "Invalid step vec"); - // FIXME: The newly created binary instructions should contain nsw/nuw flags, - // which can be found from the original scalar operations. - Step = Builder.CreateMul(Cv, Step); - return Builder.CreateAdd(Val, Step, "induction"); + + // Floating point operations had to be 'fast' to enable the induction. 
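Illustrative sketch, not part of the patch: the integer branch above produces, for lane j, the value Val + (StartIdx + j) * Step (the callers above pass StartIdx = VF * Part). A small standalone model of that arithmetic; stepVector and the constants in main are purely for demonstration.

#include <array>
#include <cstdio>

// Per-lane value of a widened integer induction: lane j gets
// Val + (StartIdx + j) * Step.
template <unsigned VF>
std::array<long long, VF> stepVector(long long Val, int StartIdx,
                                     long long Step) {
  std::array<long long, VF> Lanes{};
  for (unsigned j = 0; j < VF; ++j)
    Lanes[j] = Val + (StartIdx + static_cast<int>(j)) * Step;
  return Lanes;
}

int main() {
  // VF = 4, unroll part 1 (StartIdx = 4), step 3, scalar value 10:
  for (long long V : stepVector<4>(10, /*StartIdx=*/4, /*Step=*/3))
    std::printf("%lld ", V); // prints: 22 25 28 31
  std::printf("\n");
  return 0;
}

The floating-point branch that follows is the same idea with fmul and fadd (or fsub) carrying fast-math flags instead of integer mul and add.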
+ FastMathFlags Flags; + Flags.setUnsafeAlgebra(); + + Value *MulOp = Builder.CreateFMul(Cv, Step); + if (isa<Instruction>(MulOp)) + // Have to check, MulOp may be a constant + cast<Instruction>(MulOp)->setFastMathFlags(Flags); + + Value *BOp = Builder.CreateBinOp(BinOp, Val, MulOp, "induction"); + if (isa<Instruction>(BOp)) + cast<Instruction>(BOp)->setFastMathFlags(Flags); + return BOp; } void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, @@ -2092,98 +2397,34 @@ void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, assert(ScalarIVTy->isIntegerTy() && ScalarIVTy == Step->getType() && "Val and Step should have the same integer type"); - // Compute the scalar steps and save the results in ScalarIVMap. - for (unsigned Part = 0; Part < UF; ++Part) - for (unsigned I = 0; I < VF; ++I) { - auto *StartIdx = ConstantInt::get(ScalarIVTy, VF * Part + I); + // Determine the number of scalars we need to generate for each unroll + // iteration. If EntryVal is uniform, we only need to generate the first + // lane. Otherwise, we generate all VF values. + unsigned Lanes = + Legal->isUniformAfterVectorization(cast<Instruction>(EntryVal)) ? 1 : VF; + + // Compute the scalar steps and save the results in VectorLoopValueMap. + ScalarParts Entry(UF); + for (unsigned Part = 0; Part < UF; ++Part) { + Entry[Part].resize(VF); + for (unsigned Lane = 0; Lane < Lanes; ++Lane) { + auto *StartIdx = ConstantInt::get(ScalarIVTy, VF * Part + Lane); auto *Mul = Builder.CreateMul(StartIdx, Step); auto *Add = Builder.CreateAdd(ScalarIV, Mul); - ScalarIVMap[EntryVal].push_back(Add); + Entry[Part][Lane] = Add; } + } + VectorLoopValueMap.initScalar(EntryVal, Entry); } int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { - assert(Ptr->getType()->isPointerTy() && "Unexpected non-ptr"); - auto *SE = PSE.getSE(); - // Make sure that the pointer does not point to structs. - if (Ptr->getType()->getPointerElementType()->isAggregateType()) - return 0; - - // If this value is a pointer induction variable, we know it is consecutive. - PHINode *Phi = dyn_cast_or_null<PHINode>(Ptr); - if (Phi && Inductions.count(Phi)) { - InductionDescriptor II = Inductions[Phi]; - return II.getConsecutiveDirection(); - } - - GetElementPtrInst *Gep = getGEPInstruction(Ptr); - if (!Gep) - return 0; - - unsigned NumOperands = Gep->getNumOperands(); - Value *GpPtr = Gep->getPointerOperand(); - // If this GEP value is a consecutive pointer induction variable and all of - // the indices are constant, then we know it is consecutive. - Phi = dyn_cast<PHINode>(GpPtr); - if (Phi && Inductions.count(Phi)) { - - // Make sure that the pointer does not point to structs. - PointerType *GepPtrType = cast<PointerType>(GpPtr->getType()); - if (GepPtrType->getElementType()->isAggregateType()) - return 0; - - // Make sure that all of the index operands are loop invariant. - for (unsigned i = 1; i < NumOperands; ++i) - if (!SE->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), TheLoop)) - return 0; - - InductionDescriptor II = Inductions[Phi]; - return II.getConsecutiveDirection(); - } - - unsigned InductionOperand = getGEPInductionOperand(Gep); - - // Check that all of the gep indices are uniform except for our induction - // operand. - for (unsigned i = 0; i != NumOperands; ++i) - if (i != InductionOperand && - !SE->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), TheLoop)) - return 0; - - // We can emit wide load/stores only if the last non-zero index is the - // induction variable. 
- const SCEV *Last = nullptr; - if (!getSymbolicStrides() || !getSymbolicStrides()->count(Gep)) - Last = PSE.getSCEV(Gep->getOperand(InductionOperand)); - else { - // Because of the multiplication by a stride we can have a s/zext cast. - // We are going to replace this stride by 1 so the cast is safe to ignore. - // - // %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ] - // %0 = trunc i64 %indvars.iv to i32 - // %mul = mul i32 %0, %Stride1 - // %idxprom = zext i32 %mul to i64 << Safe cast. - // %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom - // - Last = replaceSymbolicStrideSCEV(PSE, *getSymbolicStrides(), - Gep->getOperand(InductionOperand), Gep); - if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last)) - Last = - (C->getSCEVType() == scSignExtend || C->getSCEVType() == scZeroExtend) - ? C->getOperand() - : Last; - } - if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Last)) { - const SCEV *Step = AR->getStepRecurrence(*SE); - // The memory is consecutive because the last index is consecutive - // and all other indices are loop invariant. - if (Step->isOne()) - return 1; - if (Step->isAllOnesValue()) - return -1; - } + const ValueToValueMap &Strides = getSymbolicStrides() ? *getSymbolicStrides() : + ValueToValueMap(); + int Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, true, false); + if (Stride == 1 || Stride == -1) + return Stride; return 0; } @@ -2191,23 +2432,112 @@ bool LoopVectorizationLegality::isUniform(Value *V) { return LAI->isUniform(V); } -InnerLoopVectorizer::VectorParts & +const InnerLoopVectorizer::VectorParts & InnerLoopVectorizer::getVectorValue(Value *V) { assert(V != Induction && "The new induction variable should not be used."); assert(!V->getType()->isVectorTy() && "Can't widen a vector"); + assert(!V->getType()->isVoidTy() && "Type does not produce a value"); // If we have a stride that is replaced by one, do it here. if (Legal->hasStride(V)) V = ConstantInt::get(V->getType(), 1); // If we have this scalar in the map, return it. - if (WidenMap.has(V)) - return WidenMap.get(V); + if (VectorLoopValueMap.hasVector(V)) + return VectorLoopValueMap.VectorMapStorage[V]; + + // If the value has not been vectorized, check if it has been scalarized + // instead. If it has been scalarized, and we actually need the value in + // vector form, we will construct the vector values on demand. + if (VectorLoopValueMap.hasScalar(V)) { + + // Initialize a new vector map entry. + VectorParts Entry(UF); + + // If we've scalarized a value, that value should be an instruction. + auto *I = cast<Instruction>(V); + + // If we aren't vectorizing, we can just copy the scalar map values over to + // the vector map. + if (VF == 1) { + for (unsigned Part = 0; Part < UF; ++Part) + Entry[Part] = getScalarValue(V, Part, 0); + return VectorLoopValueMap.initVector(V, Entry); + } + + // Get the last scalar instruction we generated for V. If the value is + // known to be uniform after vectorization, this corresponds to lane zero + // of the last unroll iteration. Otherwise, the last instruction is the one + // we created for the last vector lane of the last unroll iteration. + unsigned LastLane = Legal->isUniformAfterVectorization(I) ? 0 : VF - 1; + auto *LastInst = cast<Instruction>(getScalarValue(V, UF - 1, LastLane)); + + // Set the insert point after the last scalarized instruction. This ensures + // the insertelement sequence will directly follow the scalar definitions. 
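Illustrative sketch, not part of the patch: what the on-demand construction described above amounts to, modeled with plain containers instead of IR. packScalars is a hypothetical name; in the real code the per-lane writes are insertelement instructions and the uniform case is a broadcast of lane zero.

#include <vector>

using Scalar = int;
using Vec = std::vector<Scalar>;

// Pack the UF x VF scalar copies of a value into UF "vectors", one per
// unroll part. Uniform values only have lane 0, so that lane is broadcast.
// Assumes each inner entry holds VF lanes for non-uniform values.
std::vector<Vec> packScalars(const std::vector<std::vector<Scalar>> &Scalars,
                             bool IsUniform, unsigned VF) {
  std::vector<Vec> Parts;
  for (const auto &Lanes : Scalars) {   // one iteration per unroll part
    Vec V(VF);
    for (unsigned Lane = 0; Lane < VF; ++Lane)
      V[Lane] = IsUniform ? Lanes[0] : Lanes[Lane]; // broadcast vs. insert
    Parts.push_back(V);
  }
  return Parts;
}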
+ auto OldIP = Builder.saveIP(); + auto NewIP = std::next(BasicBlock::iterator(LastInst)); + Builder.SetInsertPoint(&*NewIP); + + // However, if we are vectorizing, we need to construct the vector values. + // If the value is known to be uniform after vectorization, we can just + // broadcast the scalar value corresponding to lane zero for each unroll + // iteration. Otherwise, we construct the vector values using insertelement + // instructions. Since the resulting vectors are stored in + // VectorLoopValueMap, we will only generate the insertelements once. + for (unsigned Part = 0; Part < UF; ++Part) { + Value *VectorValue = nullptr; + if (Legal->isUniformAfterVectorization(I)) { + VectorValue = getBroadcastInstrs(getScalarValue(V, Part, 0)); + } else { + VectorValue = UndefValue::get(VectorType::get(V->getType(), VF)); + for (unsigned Lane = 0; Lane < VF; ++Lane) + VectorValue = Builder.CreateInsertElement( + VectorValue, getScalarValue(V, Part, Lane), + Builder.getInt32(Lane)); + } + Entry[Part] = VectorValue; + } + Builder.restoreIP(OldIP); + return VectorLoopValueMap.initVector(V, Entry); + } // If this scalar is unknown, assume that it is a constant or that it is // loop invariant. Broadcast V and save the value for future uses. Value *B = getBroadcastInstrs(V); - return WidenMap.splat(V, B); + return VectorLoopValueMap.initVector(V, VectorParts(UF, B)); +} + +Value *InnerLoopVectorizer::getScalarValue(Value *V, unsigned Part, + unsigned Lane) { + + // If the value is not an instruction contained in the loop, it should + // already be scalar. + if (OrigLoop->isLoopInvariant(V)) + return V; + + assert(Lane > 0 ? !Legal->isUniformAfterVectorization(cast<Instruction>(V)) + : true && "Uniform values only have lane zero"); + + // If the value from the original loop has not been vectorized, it is + // represented by UF x VF scalar values in the new loop. Return the requested + // scalar value. + if (VectorLoopValueMap.hasScalar(V)) + return VectorLoopValueMap.ScalarMapStorage[V][Part][Lane]; + + // If the value has not been scalarized, get its entry in VectorLoopValueMap + // for the given unroll part. If this entry is not a vector type (i.e., the + // vectorization factor is one), there is no need to generate an + // extractelement instruction. + auto *U = getVectorValue(V)[Part]; + if (!U->getType()->isVectorTy()) { + assert(VF == 1 && "Value not scalarized has non-vector type"); + return U; + } + + // Otherwise, the value from the original loop has been vectorized and is + // represented by UF vector values. Extract and return the requested scalar + // value from the appropriate vector lane. + return Builder.CreateExtractElement(U, Builder.getInt32(Lane)); } Value *InnerLoopVectorizer::reverseVector(Value *Vec) { @@ -2355,7 +2685,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { LoadInst *LI = dyn_cast<LoadInst>(Instr); StoreInst *SI = dyn_cast<StoreInst>(Instr); - Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); + Value *Ptr = getPointerOperand(Instr); // Prepare for the vector type of the interleaved load/store. Type *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType(); @@ -2365,15 +2695,20 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { // Prepare for the new pointers. 
setDebugLocFromInst(Builder, Ptr); - VectorParts &PtrParts = getVectorValue(Ptr); SmallVector<Value *, 2> NewPtrs; unsigned Index = Group->getIndex(Instr); + + // If the group is reverse, adjust the index to refer to the last vector lane + // instead of the first. We adjust the index from the first vector lane, + // rather than directly getting the pointer for lane VF - 1, because the + // pointer operand of the interleaved access is supposed to be uniform. For + // uniform instructions, we're only required to generate a value for the + // first vector lane in each unroll iteration. + if (Group->isReverse()) + Index += (VF - 1) * Group->getFactor(); + for (unsigned Part = 0; Part < UF; Part++) { - // Extract the pointer for current instruction from the pointer vector. A - // reverse access uses the pointer in the last lane. - Value *NewPtr = Builder.CreateExtractElement( - PtrParts[Part], - Group->isReverse() ? Builder.getInt32(VF - 1) : Builder.getInt32(0)); + Value *NewPtr = getScalarValue(Ptr, Part, 0); // Notice current instruction could be any index. Need to adjust the address // to the member of index 0. @@ -2397,20 +2732,30 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { // Vectorize the interleaved load group. if (LI) { + + // For each unroll part, create a wide load for the group. + SmallVector<Value *, 2> NewLoads; for (unsigned Part = 0; Part < UF; Part++) { - Instruction *NewLoadInstr = Builder.CreateAlignedLoad( + auto *NewLoad = Builder.CreateAlignedLoad( NewPtrs[Part], Group->getAlignment(), "wide.vec"); + addMetadata(NewLoad, Instr); + NewLoads.push_back(NewLoad); + } - for (unsigned i = 0; i < InterleaveFactor; i++) { - Instruction *Member = Group->getMember(i); + // For each member in the group, shuffle out the appropriate data from the + // wide loads. + for (unsigned I = 0; I < InterleaveFactor; ++I) { + Instruction *Member = Group->getMember(I); - // Skip the gaps in the group. - if (!Member) - continue; + // Skip the gaps in the group. + if (!Member) + continue; - Constant *StrideMask = getStridedMask(Builder, i, InterleaveFactor, VF); + VectorParts Entry(UF); + Constant *StrideMask = getStridedMask(Builder, I, InterleaveFactor, VF); + for (unsigned Part = 0; Part < UF; Part++) { Value *StridedVec = Builder.CreateShuffleVector( - NewLoadInstr, UndefVec, StrideMask, "strided.vec"); + NewLoads[Part], UndefVec, StrideMask, "strided.vec"); // If this member has different type, cast the result type. if (Member->getType() != ScalarTy) { @@ -2418,12 +2763,10 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { StridedVec = Builder.CreateBitOrPointerCast(StridedVec, OtherVTy); } - VectorParts &Entry = WidenMap.get(Member); Entry[Part] = Group->isReverse() ? reverseVector(StridedVec) : StridedVec; } - - addMetadata(NewLoadInstr, Instr); + VectorLoopValueMap.initVector(Member, Entry); } return; } @@ -2479,7 +2822,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { Type *ScalarDataTy = LI ? LI->getType() : SI->getValueOperand()->getType(); Type *DataTy = VectorType::get(ScalarDataTy, VF); - Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); + Value *Ptr = getPointerOperand(Instr); unsigned Alignment = LI ? LI->getAlignment() : SI->getAlignment(); // An alignment of 0 means target abi alignment. We need to use the scalar's // target abi alignment in such a case. 
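Illustrative sketch, not part of the patch: two index computations used by the interleaved-group code above, modeled as standalone helpers. stridedMask assumes the extraction mask is { MemberIdx, MemberIdx + Factor, ..., MemberIdx + (VF - 1) * Factor }, matching the strided shuffles shown earlier; groupStartIndex mirrors the reverse-group adjustment Index += (VF - 1) * Factor from the hunk above.

#include <vector>

// Shuffle mask that pulls member `MemberIdx` out of a wide interleaved load
// of VF * Factor elements (assumed layout: members interleaved per element).
std::vector<unsigned> stridedMask(unsigned MemberIdx, unsigned Factor,
                                  unsigned VF) {
  std::vector<unsigned> Mask;
  for (unsigned Lane = 0; Lane < VF; ++Lane)
    Mask.push_back(MemberIdx + Lane * Factor);
  return Mask;
}

// Starting index within the group; a reversed group addresses the last
// vector lane first, so the index is advanced by (VF - 1) * Factor.
unsigned groupStartIndex(unsigned Index, bool Reverse, unsigned VF,
                         unsigned Factor) {
  return Reverse ? Index + (VF - 1) * Factor : Index;
}

For VF = 4 and Factor = 2, member 1 of the group is extracted from the wide load with the mask {1, 3, 5, 7}.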
@@ -2487,93 +2830,57 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { if (!Alignment) Alignment = DL.getABITypeAlignment(ScalarDataTy); unsigned AddressSpace = Ptr->getType()->getPointerAddressSpace(); - uint64_t ScalarAllocatedSize = DL.getTypeAllocSize(ScalarDataTy); - uint64_t VectorElementSize = DL.getTypeStoreSize(DataTy) / VF; - if (SI && Legal->blockNeedsPredication(SI->getParent()) && - !Legal->isMaskRequired(SI)) - return scalarizeInstruction(Instr, true); + // Scalarize the memory instruction if necessary. + if (Legal->memoryInstructionMustBeScalarized(Instr, VF)) + return scalarizeInstruction(Instr, Legal->isScalarWithPredication(Instr)); - if (ScalarAllocatedSize != VectorElementSize) - return scalarizeInstruction(Instr); - - // If the pointer is loop invariant scalarize the load. - if (LI && Legal->isUniform(Ptr)) - return scalarizeInstruction(Instr); - - // If the pointer is non-consecutive and gather/scatter is not supported - // scalarize the instruction. + // Determine if the pointer operand of the access is either consecutive or + // reverse consecutive. int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); bool Reverse = ConsecutiveStride < 0; - bool CreateGatherScatter = - !ConsecutiveStride && ((LI && Legal->isLegalMaskedGather(ScalarDataTy)) || - (SI && Legal->isLegalMaskedScatter(ScalarDataTy))); - if (!ConsecutiveStride && !CreateGatherScatter) - return scalarizeInstruction(Instr); + // Determine if either a gather or scatter operation is legal. + bool CreateGatherScatter = + !ConsecutiveStride && Legal->isLegalGatherOrScatter(Instr); - Constant *Zero = Builder.getInt32(0); - VectorParts &Entry = WidenMap.get(Instr); VectorParts VectorGep; // Handle consecutive loads/stores. GetElementPtrInst *Gep = getGEPInstruction(Ptr); if (ConsecutiveStride) { - if (Gep && Legal->isInductionVariable(Gep->getPointerOperand())) { - setDebugLocFromInst(Builder, Gep); - Value *PtrOperand = Gep->getPointerOperand(); - Value *FirstBasePtr = getVectorValue(PtrOperand)[0]; - FirstBasePtr = Builder.CreateExtractElement(FirstBasePtr, Zero); - - // Create the new GEP with the new induction variable. - GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone()); - Gep2->setOperand(0, FirstBasePtr); - Gep2->setName("gep.indvar.base"); - Ptr = Builder.Insert(Gep2); - } else if (Gep) { - setDebugLocFromInst(Builder, Gep); - assert(PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getPointerOperand()), - OrigLoop) && - "Base ptr must be invariant"); - // The last index does not have to be the induction. It can be - // consecutive and be a function of the index. For example A[I+1]; + if (Gep) { unsigned NumOperands = Gep->getNumOperands(); - unsigned InductionOperand = getGEPInductionOperand(Gep); - // Create the new GEP with the new induction variable. +#ifndef NDEBUG + // The original GEP that identified as a consecutive memory access + // should have only one loop-variant operand. + unsigned NumOfLoopVariantOps = 0; + for (unsigned i = 0; i < NumOperands; ++i) + if (!PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), + OrigLoop)) + NumOfLoopVariantOps++; + assert(NumOfLoopVariantOps == 1 && + "Consecutive GEP should have only one loop-variant operand"); +#endif GetElementPtrInst *Gep2 = cast<GetElementPtrInst>(Gep->clone()); - - for (unsigned i = 0; i < NumOperands; ++i) { - Value *GepOperand = Gep->getOperand(i); - Instruction *GepOperandInst = dyn_cast<Instruction>(GepOperand); - - // Update last index or loop invariant instruction anchored in loop. 
- if (i == InductionOperand || - (GepOperandInst && OrigLoop->contains(GepOperandInst))) { - assert((i == InductionOperand || - PSE.getSE()->isLoopInvariant(PSE.getSCEV(GepOperandInst), - OrigLoop)) && - "Must be last index or loop invariant"); - - VectorParts &GEPParts = getVectorValue(GepOperand); - - // If GepOperand is an induction variable, and there's a scalarized - // version of it available, use it. Otherwise, we will need to create - // an extractelement instruction. - Value *Index = ScalarIVMap.count(GepOperand) - ? ScalarIVMap[GepOperand][0] - : Builder.CreateExtractElement(GEPParts[0], Zero); - - Gep2->setOperand(i, Index); - Gep2->setName("gep.indvar.idx"); - } - } + Gep2->setName("gep.indvar"); + + // A new GEP is created for a 0-lane value of the first unroll iteration. + // The GEPs for the rest of the unroll iterations are computed below as an + // offset from this GEP. + for (unsigned i = 0; i < NumOperands; ++i) + // We can apply getScalarValue() for all GEP indices. It returns an + // original value for loop-invariant operand and 0-lane for consecutive + // operand. + Gep2->setOperand(i, getScalarValue(Gep->getOperand(i), + 0, /* First unroll iteration */ + 0 /* 0-lane of the vector */ )); + setDebugLocFromInst(Builder, Gep); Ptr = Builder.Insert(Gep2); + } else { // No GEP - // Use the induction element ptr. - assert(isa<PHINode>(Ptr) && "Invalid induction ptr"); setDebugLocFromInst(Builder, Ptr); - VectorParts &PtrVal = getVectorValue(Ptr); - Ptr = Builder.CreateExtractElement(PtrVal[0], Zero); + Ptr = getScalarValue(Ptr, 0, 0); } } else { // At this point we should vector version of GEP for Gather or Scatter @@ -2660,6 +2967,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { // Handle loads. assert(LI && "Must have a load instruction"); setDebugLocFromInst(Builder, LI); + VectorParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { Instruction *NewLI; if (CreateGatherScatter) { @@ -2692,70 +3000,45 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { } addMetadata(NewLI, LI); } + VectorLoopValueMap.initVector(Instr, Entry); } void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, - bool IfPredicateStore) { + bool IfPredicateInstr) { assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); + DEBUG(dbgs() << "LV: Scalarizing" + << (IfPredicateInstr ? " and predicating:" : ":") << *Instr + << '\n'); // Holds vector parameters or scalars, in case of uniform vals. SmallVector<VectorParts, 4> Params; setDebugLocFromInst(Builder, Instr); - // Find all of the vectorized parameters. - for (Value *SrcOp : Instr->operands()) { - // If we are accessing the old induction variable, use the new one. - if (SrcOp == OldInduction) { - Params.push_back(getVectorValue(SrcOp)); - continue; - } - - // Try using previously calculated values. - auto *SrcInst = dyn_cast<Instruction>(SrcOp); - - // If the src is an instruction that appeared earlier in the basic block, - // then it should already be vectorized. - if (SrcInst && OrigLoop->contains(SrcInst)) { - assert(WidenMap.has(SrcInst) && "Source operand is unavailable"); - // The parameter is a vector value from earlier. - Params.push_back(WidenMap.get(SrcInst)); - } else { - // The parameter is a scalar from outside the loop. Maybe even a constant. 
- VectorParts Scalars; - Scalars.append(UF, SrcOp); - Params.push_back(Scalars); - } - } - - assert(Params.size() == Instr->getNumOperands() && - "Invalid number of operands"); - // Does this instruction return a value ? bool IsVoidRetTy = Instr->getType()->isVoidTy(); - Value *UndefVec = - IsVoidRetTy ? nullptr - : UndefValue::get(VectorType::get(Instr->getType(), VF)); - // Create a new entry in the WidenMap and initialize it to Undef or Null. - VectorParts &VecResults = WidenMap.splat(Instr, UndefVec); + // Initialize a new scalar map entry. + ScalarParts Entry(UF); VectorParts Cond; - if (IfPredicateStore) { - assert(Instr->getParent()->getSinglePredecessor() && - "Only support single predecessor blocks"); - Cond = createEdgeMask(Instr->getParent()->getSinglePredecessor(), - Instr->getParent()); - } + if (IfPredicateInstr) + Cond = createBlockInMask(Instr->getParent()); + + // Determine the number of scalars we need to generate for each unroll + // iteration. If the instruction is uniform, we only need to generate the + // first lane. Otherwise, we generate all VF values. + unsigned Lanes = Legal->isUniformAfterVectorization(Instr) ? 1 : VF; // For each vector unroll 'part': for (unsigned Part = 0; Part < UF; ++Part) { + Entry[Part].resize(VF); // For each scalar that we create: - for (unsigned Width = 0; Width < VF; ++Width) { + for (unsigned Lane = 0; Lane < Lanes; ++Lane) { // Start if-block. Value *Cmp = nullptr; - if (IfPredicateStore) { - Cmp = Builder.CreateExtractElement(Cond[Part], Builder.getInt32(Width)); + if (IfPredicateInstr) { + Cmp = Builder.CreateExtractElement(Cond[Part], Builder.getInt32(Lane)); Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cmp, ConstantInt::get(Cmp->getType(), 1)); } @@ -2763,18 +3046,11 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, Instruction *Cloned = Instr->clone(); if (!IsVoidRetTy) Cloned->setName(Instr->getName() + ".cloned"); - // Replace the operands of the cloned instructions with extracted scalars. - for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { - // If the operand is an induction variable, and there's a scalarized - // version of it available, use it. Otherwise, we will need to create - // an extractelement instruction if vectorizing. - auto *NewOp = Params[op][Part]; - auto *ScalarOp = Instr->getOperand(op); - if (ScalarIVMap.count(ScalarOp)) - NewOp = ScalarIVMap[ScalarOp][VF * Part + Width]; - else if (NewOp->getType()->isVectorTy()) - NewOp = Builder.CreateExtractElement(NewOp, Builder.getInt32(Width)); + // Replace the operands of the cloned instructions with their scalar + // equivalents in the new loop. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + auto *NewOp = getScalarValue(Instr->getOperand(op), Part, Lane); Cloned->setOperand(op, NewOp); } addNewMetadata(Cloned, Instr); @@ -2782,22 +3058,20 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, // Place the cloned scalar in the new loop. Builder.Insert(Cloned); + // Add the cloned scalar to the scalar map entry. + Entry[Part][Lane] = Cloned; + // If we just cloned a new assumption, add it the assumption cache. if (auto *II = dyn_cast<IntrinsicInst>(Cloned)) if (II->getIntrinsicID() == Intrinsic::assume) AC->registerAssumption(II); - // If the original scalar returns a value we need to place it in a vector - // so that future users will be able to use it. - if (!IsVoidRetTy) - VecResults[Part] = Builder.CreateInsertElement(VecResults[Part], Cloned, - Builder.getInt32(Width)); // End if-block. 
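Editor's note: the new ScalarParts bookkeeping in scalarizeInstruction can be modeled with nested vectors — one clone per unroll part and lane, or only lane 0 when the instruction is uniform after vectorization. A toy, self-contained sketch, with string placeholders standing in for cloned instructions:

#include <string>
#include <vector>

// Entry[Part][Lane] holds the clone created for unroll part Part and vector
// lane Lane, mirroring the Lanes computation above. Names are illustrative.
using ScalarPartsModel = std::vector<std::vector<std::string>>;

ScalarPartsModel scalarizeModel(const std::string &Inst, unsigned UF,
                                unsigned VF, bool UniformAfterVectorization) {
  // Uniform instructions only need a value for lane 0 of each part.
  unsigned Lanes = UniformAfterVectorization ? 1 : VF;
  ScalarPartsModel Entry(UF, std::vector<std::string>(VF));
  for (unsigned Part = 0; Part < UF; ++Part)
    for (unsigned Lane = 0; Lane < Lanes; ++Lane)
      Entry[Part][Lane] = Inst + ".part" + std::to_string(Part) + ".lane" +
                          std::to_string(Lane);
  return Entry;
}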
- if (IfPredicateStore) - PredicatedStores.push_back( - std::make_pair(cast<StoreInst>(Cloned), Cmp)); + if (IfPredicateInstr) + PredicatedInstructions.push_back(std::make_pair(Cloned, Cmp)); } } + VectorLoopValueMap.initScalar(Instr, Entry); } PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, @@ -2811,10 +3085,12 @@ PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, Latch = Header; IRBuilder<> Builder(&*Header->getFirstInsertionPt()); - setDebugLocFromInst(Builder, getDebugLocFromInstOrOperands(OldInduction)); + Instruction *OldInst = getDebugLocFromInstOrOperands(OldInduction); + setDebugLocFromInst(Builder, OldInst); auto *Induction = Builder.CreatePHI(Start->getType(), 2, "index"); Builder.SetInsertPoint(Latch->getTerminator()); + setDebugLocFromInst(Builder, OldInst); // Create i+1 and fill the PHINode. Value *Next = Builder.CreateAdd(Induction, Step, "index.next"); @@ -3146,14 +3422,16 @@ void InnerLoopVectorizer::createEmptyLoop() { // Create phi nodes to merge from the backedge-taken check block. PHINode *BCResumeVal = PHINode::Create( OrigPhi->getType(), 3, "bc.resume.val", ScalarPH->getTerminator()); - Value *EndValue; + Value *&EndValue = IVEndValues[OrigPhi]; if (OrigPhi == OldInduction) { // We know what the end value is. EndValue = CountRoundDown; } else { IRBuilder<> B(LoopBypassBlocks.back()->getTerminator()); - Value *CRD = B.CreateSExtOrTrunc(CountRoundDown, - II.getStep()->getType(), "cast.crd"); + Type *StepType = II.getStep()->getType(); + Instruction::CastOps CastOp = + CastInst::getCastOpcode(CountRoundDown, true, StepType, true); + Value *CRD = B.CreateCast(CastOp, CountRoundDown, StepType, "cast.crd"); const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); EndValue = II.transform(B, CRD, PSE.getSE(), DL); EndValue->setName("ind.end"); @@ -3163,9 +3441,6 @@ void InnerLoopVectorizer::createEmptyLoop() { // or the value at the end of the vectorized loop. BCResumeVal->addIncoming(EndValue, MiddleBlock); - // Fix up external users of the induction variable. - fixupIVUsers(OrigPhi, II, CountRoundDown, EndValue, MiddleBlock); - // Fix the scalar body counter (PHI node). unsigned BlockIdx = OrigPhi->getBasicBlockIndex(ScalarPH); @@ -3201,7 +3476,7 @@ void InnerLoopVectorizer::createEmptyLoop() { if (MDNode *LID = OrigLoop->getLoopID()) Lp->setLoopID(LID); - LoopVectorizeHints Hints(Lp, true); + LoopVectorizeHints Hints(Lp, true, *ORE); Hints.setAlreadyVectorized(); } @@ -3324,8 +3599,9 @@ static Value *addFastMathFlag(Value *V) { return V; } -/// Estimate the overhead of scalarizing a value. Insert and Extract are set if -/// the result needs to be inserted and/or extracted from vectors. +/// \brief Estimate the overhead of scalarizing a value based on its type. +/// Insert and Extract are set if the result needs to be inserted and/or +/// extracted from vectors. 
static unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract, const TargetTransformInfo &TTI) { if (Ty->isVoidTy()) @@ -3335,15 +3611,46 @@ static unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract, unsigned Cost = 0; for (unsigned I = 0, E = Ty->getVectorNumElements(); I < E; ++I) { - if (Insert) - Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, I); if (Extract) Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, Ty, I); + if (Insert) + Cost += TTI.getVectorInstrCost(Instruction::InsertElement, Ty, I); } return Cost; } +/// \brief Estimate the overhead of scalarizing an Instruction based on the +/// types of its operands and return value. +static unsigned getScalarizationOverhead(SmallVectorImpl<Type *> &OpTys, + Type *RetTy, + const TargetTransformInfo &TTI) { + unsigned ScalarizationCost = + getScalarizationOverhead(RetTy, true, false, TTI); + + for (Type *Ty : OpTys) + ScalarizationCost += getScalarizationOverhead(Ty, false, true, TTI); + + return ScalarizationCost; +} + +/// \brief Estimate the overhead of scalarizing an instruction. This is a +/// convenience wrapper for the type-based getScalarizationOverhead API. +static unsigned getScalarizationOverhead(Instruction *I, unsigned VF, + const TargetTransformInfo &TTI) { + if (VF == 1) + return 0; + + Type *RetTy = ToVectorTy(I->getType(), VF); + + SmallVector<Type *, 4> OpTys; + unsigned OperandsNum = I->getNumOperands(); + for (unsigned OpInd = 0; OpInd < OperandsNum; ++OpInd) + OpTys.push_back(ToVectorTy(I->getOperand(OpInd)->getType(), VF)); + + return getScalarizationOverhead(OpTys, RetTy, TTI); +} + // Estimate cost of a call instruction CI if it were vectorized with factor VF. // Return the cost of the instruction, including scalarization overhead if it's // needed. The flag NeedToScalarize shows if the call needs to be scalarized - @@ -3374,10 +3681,7 @@ static unsigned getVectorCallCost(CallInst *CI, unsigned VF, // Compute costs of unpacking argument values for the scalar calls and // packing the return values to a vector. - unsigned ScalarizationCost = - getScalarizationOverhead(RetTy, true, false, TTI); - for (Type *Ty : Tys) - ScalarizationCost += getScalarizationOverhead(Ty, false, true, TTI); + unsigned ScalarizationCost = getScalarizationOverhead(Tys, RetTy, TTI); unsigned Cost = ScalarCallCost * VF + ScalarizationCost; @@ -3434,8 +3738,13 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { // later and will remove any ext/trunc pairs. // SmallPtrSet<Value *, 4> Erased; - for (const auto &KV : *MinBWs) { - VectorParts &Parts = WidenMap.get(KV.first); + for (const auto &KV : Cost->getMinimalBitwidths()) { + // If the value wasn't vectorized, we must maintain the original scalar + // type. The absence of the value from VectorLoopValueMap indicates that it + // wasn't vectorized. + if (!VectorLoopValueMap.hasVector(KV.first)) + continue; + VectorParts &Parts = VectorLoopValueMap.getVector(KV.first); for (Value *&I : Parts) { if (Erased.count(I) || I->use_empty() || !isa<Instruction>(I)) continue; @@ -3526,8 +3835,13 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { } // We'll have created a bunch of ZExts that are now parentless. Clean up. - for (const auto &KV : *MinBWs) { - VectorParts &Parts = WidenMap.get(KV.first); + for (const auto &KV : Cost->getMinimalBitwidths()) { + // If the value wasn't vectorized, we must maintain the original scalar + // type. The absence of the value from VectorLoopValueMap indicates that it + // wasn't vectorized. 
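Editor's note: the overhead helpers above sum insertelement costs for the result vector and extractelement costs for every operand vector. A simplified stand-in that keeps the same structure, with unit costs in place of the TTI queries and with void results and scalar operands ignored for brevity:

// Cost of packing the result (one insert per lane) plus unpacking each
// operand (one extract per lane). Unit costs are placeholders.
unsigned scalarizationOverheadModel(unsigned VF, unsigned NumOperands,
                                    unsigned InsertCost = 1,
                                    unsigned ExtractCost = 1) {
  if (VF == 1)
    return 0;                             // nothing to pack or unpack
  unsigned Cost = VF * InsertCost;        // build the vector result
  Cost += NumOperands * VF * ExtractCost; // unpack each vector operand
  return Cost;
}

With unit costs, scalarizationOverheadModel(4, 2) gives 4 + 8 = 12.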
+ if (!VectorLoopValueMap.hasVector(KV.first)) + continue; + VectorParts &Parts = VectorLoopValueMap.getVector(KV.first); for (Value *&I : Parts) { ZExtInst *Inst = dyn_cast<ZExtInst>(I); if (Inst && Inst->use_empty()) { @@ -3558,6 +3872,11 @@ void InnerLoopVectorizer::vectorizeLoop() { // are vectorized, so we can use them to construct the PHI. PhiVector PHIsToFix; + // Collect instructions from the original loop that will become trivially + // dead in the vectorized loop. We don't need to vectorize these + // instructions. + collectTriviallyDeadInstructions(); + // Scan the loop in a topological order to ensure that defs are vectorized // before users. LoopBlocksDFS DFS(OrigLoop); @@ -3605,7 +3924,7 @@ void InnerLoopVectorizer::vectorizeLoop() { Builder.SetInsertPoint(LoopBypassBlocks[1]->getTerminator()); // This is the vector-clone of the value that leaves the loop. - VectorParts &VectorExit = getVectorValue(LoopExitInst); + const VectorParts &VectorExit = getVectorValue(LoopExitInst); Type *VecTy = VectorExit[0]->getType(); // Find the reduction identity variable. Zero for addition, or, xor, @@ -3644,10 +3963,10 @@ void InnerLoopVectorizer::vectorizeLoop() { // Reductions do not have to start at zero. They can start with // any loop invariant values. - VectorParts &VecRdxPhi = WidenMap.get(Phi); + const VectorParts &VecRdxPhi = getVectorValue(Phi); BasicBlock *Latch = OrigLoop->getLoopLatch(); Value *LoopVal = Phi->getIncomingValueForBlock(Latch); - VectorParts &Val = getVectorValue(LoopVal); + const VectorParts &Val = getVectorValue(LoopVal); for (unsigned part = 0; part < UF; ++part) { // Make sure to add the reduction stat value only to the // first unroll part. @@ -3664,7 +3983,7 @@ void InnerLoopVectorizer::vectorizeLoop() { // instructions. Builder.SetInsertPoint(&*LoopMiddleBlock->getFirstInsertionPt()); - VectorParts RdxParts = getVectorValue(LoopExitInst); + VectorParts &RdxParts = VectorLoopValueMap.getVector(LoopExitInst); setDebugLocFromInst(Builder, LoopExitInst); // If the vector reduction can be performed in a smaller type, we truncate @@ -3792,22 +4111,25 @@ void InnerLoopVectorizer::vectorizeLoop() { Phi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst); } // end of for each Phi in PHIsToFix. - fixLCSSAPHIs(); - - // Make sure DomTree is updated. + // Update the dominator tree. + // + // FIXME: After creating the structure of the new loop, the dominator tree is + // no longer up-to-date, and it remains that way until we update it + // here. An out-of-date dominator tree is problematic for SCEV, + // because SCEVExpander uses it to guide code generation. The + // vectorizer use SCEVExpanders in several places. Instead, we should + // keep the dominator tree up-to-date as we go. updateAnalysis(); - // Predicate any stores. - for (auto KV : PredicatedStores) { - BasicBlock::iterator I(KV.first); - auto *BB = SplitBlock(I->getParent(), &*std::next(I), DT, LI); - auto *T = SplitBlockAndInsertIfThen(KV.second, &*I, /*Unreachable=*/false, - /*BranchWeights=*/nullptr, DT, LI); - I->moveBefore(T); - I->getParent()->setName("pred.store.if"); - BB->setName("pred.store.continue"); - } - DEBUG(DT->verifyDomTree()); + // Fix-up external users of the induction variables. + for (auto &Entry : *Legal->getInductionVars()) + fixupIVUsers(Entry.first, Entry.second, + getOrCreateVectorTripCount(LI->getLoopFor(LoopVectorBody)), + IVEndValues[Entry.first], LoopMiddleBlock); + + fixLCSSAPHIs(); + predicateInstructions(); + // Remove redundant induction instructions. 
cse(LoopVectorBody); } @@ -3882,7 +4204,7 @@ void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { // We constructed a temporary phi node in the first phase of vectorization. // This phi node will eventually be deleted. - auto &PhiParts = getVectorValue(Phi); + VectorParts &PhiParts = VectorLoopValueMap.getVector(Phi); Builder.SetInsertPoint(cast<Instruction>(PhiParts[0])); // Create a phi node for the new recurrence. The current value will either be @@ -3974,10 +4296,217 @@ void InnerLoopVectorizer::fixLCSSAPHIs() { } } +void InnerLoopVectorizer::collectTriviallyDeadInstructions() { + BasicBlock *Latch = OrigLoop->getLoopLatch(); + + // We create new control-flow for the vectorized loop, so the original + // condition will be dead after vectorization if it's only used by the + // branch. + auto *Cmp = dyn_cast<Instruction>(Latch->getTerminator()->getOperand(0)); + if (Cmp && Cmp->hasOneUse()) + DeadInstructions.insert(Cmp); + + // We create new "steps" for induction variable updates to which the original + // induction variables map. An original update instruction will be dead if + // all its users except the induction variable are dead. + for (auto &Induction : *Legal->getInductionVars()) { + PHINode *Ind = Induction.first; + auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); + if (all_of(IndUpdate->users(), [&](User *U) -> bool { + return U == Ind || DeadInstructions.count(cast<Instruction>(U)); + })) + DeadInstructions.insert(IndUpdate); + } +} + +void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { + + // The basic block and loop containing the predicated instruction. + auto *PredBB = PredInst->getParent(); + auto *VectorLoop = LI->getLoopFor(PredBB); + + // Initialize a worklist with the operands of the predicated instruction. + SetVector<Value *> Worklist(PredInst->op_begin(), PredInst->op_end()); + + // Holds instructions that we need to analyze again. An instruction may be + // reanalyzed if we don't yet know if we can sink it or not. + SmallVector<Instruction *, 8> InstsToReanalyze; + + // Returns true if a given use occurs in the predicated block. Phi nodes use + // their operands in their corresponding predecessor blocks. + auto isBlockOfUsePredicated = [&](Use &U) -> bool { + auto *I = cast<Instruction>(U.getUser()); + BasicBlock *BB = I->getParent(); + if (auto *Phi = dyn_cast<PHINode>(I)) + BB = Phi->getIncomingBlock( + PHINode::getIncomingValueNumForOperand(U.getOperandNo())); + return BB == PredBB; + }; + + // Iteratively sink the scalarized operands of the predicated instruction + // into the block we created for it. When an instruction is sunk, it's + // operands are then added to the worklist. The algorithm ends after one pass + // through the worklist doesn't sink a single instruction. + bool Changed; + do { + + // Add the instructions that need to be reanalyzed to the worklist, and + // reset the changed indicator. + Worklist.insert(InstsToReanalyze.begin(), InstsToReanalyze.end()); + InstsToReanalyze.clear(); + Changed = false; + + while (!Worklist.empty()) { + auto *I = dyn_cast<Instruction>(Worklist.pop_back_val()); + + // We can't sink an instruction if it is a phi node, is already in the + // predicated block, is not in the loop, or may have side effects. + if (!I || isa<PHINode>(I) || I->getParent() == PredBB || + !VectorLoop->contains(I) || I->mayHaveSideEffects()) + continue; + + // It's legal to sink the instruction if all its uses occur in the + // predicated block. 
Otherwise, there's nothing to do yet, and we may + // need to reanalyze the instruction. + if (!all_of(I->uses(), isBlockOfUsePredicated)) { + InstsToReanalyze.push_back(I); + continue; + } + + // Move the instruction to the beginning of the predicated block, and add + // it's operands to the worklist. + I->moveBefore(&*PredBB->getFirstInsertionPt()); + Worklist.insert(I->op_begin(), I->op_end()); + + // The sinking may have enabled other instructions to be sunk, so we will + // need to iterate. + Changed = true; + } + } while (Changed); +} + +void InnerLoopVectorizer::predicateInstructions() { + + // For each instruction I marked for predication on value C, split I into its + // own basic block to form an if-then construct over C. Since I may be fed by + // an extractelement instruction or other scalar operand, we try to + // iteratively sink its scalar operands into the predicated block. If I feeds + // an insertelement instruction, we try to move this instruction into the + // predicated block as well. For non-void types, a phi node will be created + // for the resulting value (either vector or scalar). + // + // So for some predicated instruction, e.g. the conditional sdiv in: + // + // for.body: + // ... + // %add = add nsw i32 %mul, %0 + // %cmp5 = icmp sgt i32 %2, 7 + // br i1 %cmp5, label %if.then, label %if.end + // + // if.then: + // %div = sdiv i32 %0, %1 + // br label %if.end + // + // if.end: + // %x.0 = phi i32 [ %div, %if.then ], [ %add, %for.body ] + // + // the sdiv at this point is scalarized and if-converted using a select. + // The inactive elements in the vector are not used, but the predicated + // instruction is still executed for all vector elements, essentially: + // + // vector.body: + // ... + // %17 = add nsw <2 x i32> %16, %wide.load + // %29 = extractelement <2 x i32> %wide.load, i32 0 + // %30 = extractelement <2 x i32> %wide.load51, i32 0 + // %31 = sdiv i32 %29, %30 + // %32 = insertelement <2 x i32> undef, i32 %31, i32 0 + // %35 = extractelement <2 x i32> %wide.load, i32 1 + // %36 = extractelement <2 x i32> %wide.load51, i32 1 + // %37 = sdiv i32 %35, %36 + // %38 = insertelement <2 x i32> %32, i32 %37, i32 1 + // %predphi = select <2 x i1> %26, <2 x i32> %38, <2 x i32> %17 + // + // Predication will now re-introduce the original control flow to avoid false + // side-effects by the sdiv instructions on the inactive elements, yielding + // (after cleanup): + // + // vector.body: + // ... 
+ // %5 = add nsw <2 x i32> %4, %wide.load + // %8 = icmp sgt <2 x i32> %wide.load52, <i32 7, i32 7> + // %9 = extractelement <2 x i1> %8, i32 0 + // br i1 %9, label %pred.sdiv.if, label %pred.sdiv.continue + // + // pred.sdiv.if: + // %10 = extractelement <2 x i32> %wide.load, i32 0 + // %11 = extractelement <2 x i32> %wide.load51, i32 0 + // %12 = sdiv i32 %10, %11 + // %13 = insertelement <2 x i32> undef, i32 %12, i32 0 + // br label %pred.sdiv.continue + // + // pred.sdiv.continue: + // %14 = phi <2 x i32> [ undef, %vector.body ], [ %13, %pred.sdiv.if ] + // %15 = extractelement <2 x i1> %8, i32 1 + // br i1 %15, label %pred.sdiv.if54, label %pred.sdiv.continue55 + // + // pred.sdiv.if54: + // %16 = extractelement <2 x i32> %wide.load, i32 1 + // %17 = extractelement <2 x i32> %wide.load51, i32 1 + // %18 = sdiv i32 %16, %17 + // %19 = insertelement <2 x i32> %14, i32 %18, i32 1 + // br label %pred.sdiv.continue55 + // + // pred.sdiv.continue55: + // %20 = phi <2 x i32> [ %14, %pred.sdiv.continue ], [ %19, %pred.sdiv.if54 ] + // %predphi = select <2 x i1> %8, <2 x i32> %20, <2 x i32> %5 + + for (auto KV : PredicatedInstructions) { + BasicBlock::iterator I(KV.first); + BasicBlock *Head = I->getParent(); + auto *BB = SplitBlock(Head, &*std::next(I), DT, LI); + auto *T = SplitBlockAndInsertIfThen(KV.second, &*I, /*Unreachable=*/false, + /*BranchWeights=*/nullptr, DT, LI); + I->moveBefore(T); + sinkScalarOperands(&*I); + + I->getParent()->setName(Twine("pred.") + I->getOpcodeName() + ".if"); + BB->setName(Twine("pred.") + I->getOpcodeName() + ".continue"); + + // If the instruction is non-void create a Phi node at reconvergence point. + if (!I->getType()->isVoidTy()) { + Value *IncomingTrue = nullptr; + Value *IncomingFalse = nullptr; + + if (I->hasOneUse() && isa<InsertElementInst>(*I->user_begin())) { + // If the predicated instruction is feeding an insert-element, move it + // into the Then block; Phi node will be created for the vector. + InsertElementInst *IEI = cast<InsertElementInst>(*I->user_begin()); + IEI->moveBefore(T); + IncomingTrue = IEI; // the new vector with the inserted element. + IncomingFalse = IEI->getOperand(0); // the unmodified vector + } else { + // Phi node will be created for the scalar predicated instruction. + IncomingTrue = &*I; + IncomingFalse = UndefValue::get(I->getType()); + } + + BasicBlock *PostDom = I->getParent()->getSingleSuccessor(); + assert(PostDom && "Then block has multiple successors"); + PHINode *Phi = + PHINode::Create(IncomingTrue->getType(), 2, "", &PostDom->front()); + IncomingTrue->replaceAllUsesWith(Phi); + Phi->addIncoming(IncomingFalse, Head); + Phi->addIncoming(IncomingTrue, I->getParent()); + } + } + + DEBUG(DT->verifyDomTree()); +} + InnerLoopVectorizer::VectorParts InnerLoopVectorizer::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { - assert(std::find(pred_begin(Dst), pred_end(Dst), Src) != pred_end(Dst) && - "Invalid edge"); + assert(is_contained(predecessors(Dst), Src) && "Invalid edge"); // Look for cached value. std::pair<BasicBlock *, BasicBlock *> Edge(Src, Dst); @@ -4033,12 +4562,12 @@ InnerLoopVectorizer::createBlockInMask(BasicBlock *BB) { return BlockMask; } -void InnerLoopVectorizer::widenPHIInstruction( - Instruction *PN, InnerLoopVectorizer::VectorParts &Entry, unsigned UF, - unsigned VF, PhiVector *PV) { +void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, + unsigned VF, PhiVector *PV) { PHINode *P = cast<PHINode>(PN); // Handle recurrences. 
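Editor's note: predicateInstructions above calls sinkScalarOperands for each split-off instruction, and the core of that routine is a fixed-point worklist. Below is a standalone model of that loop; CanSink stands in for the real legality checks (not a phi, inside the vector loop, no side effects, all uses already in the predicated block), and the real code additionally adds the operands of each sunk instruction back onto the worklist.

#include <functional>
#include <set>
#include <string>
#include <vector>

void sinkModel(std::vector<std::string> Worklist,
               const std::function<bool(const std::string &)> &CanSink,
               std::set<std::string> &Sunk) {
  std::vector<std::string> Reanalyze;
  bool Changed;
  do {
    Worklist.insert(Worklist.end(), Reanalyze.begin(), Reanalyze.end());
    Reanalyze.clear();
    Changed = false;
    while (!Worklist.empty()) {
      std::string I = Worklist.back();
      Worklist.pop_back();
      if (!CanSink(I)) {
        Reanalyze.push_back(I); // may become sinkable after other sinks
        continue;
      }
      Sunk.insert(I);  // conceptually: move into the predicated block
      Changed = true;  // sinking may enable further sinking, so iterate
    }
  } while (Changed);
}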
if (Legal->isReductionVariable(P) || Legal->isFirstOrderRecurrence(P)) { + VectorParts Entry(UF); for (unsigned part = 0; part < UF; ++part) { // This is phase one of vectorizing PHIs. Type *VecTy = @@ -4046,6 +4575,7 @@ void InnerLoopVectorizer::widenPHIInstruction( Entry[part] = PHINode::Create( VecTy, 2, "vec.phi", &*LoopVectorBody->getFirstInsertionPt()); } + VectorLoopValueMap.initVector(P, Entry); PV->push_back(P); return; } @@ -4066,10 +4596,11 @@ void InnerLoopVectorizer::widenPHIInstruction( // SELECT(Mask3, In3, // SELECT(Mask2, In2, // ( ...))) + VectorParts Entry(UF); for (unsigned In = 0; In < NumIncoming; In++) { VectorParts Cond = createEdgeMask(P->getIncomingBlock(In), P->getParent()); - VectorParts &In0 = getVectorValue(P->getIncomingValue(In)); + const VectorParts &In0 = getVectorValue(P->getIncomingValue(In)); for (unsigned part = 0; part < UF; ++part) { // We might have single edge PHIs (blocks) - use an identity @@ -4083,6 +4614,7 @@ void InnerLoopVectorizer::widenPHIInstruction( "predphi"); } } + VectorLoopValueMap.initVector(P, Entry); return; } @@ -4099,46 +4631,95 @@ void InnerLoopVectorizer::widenPHIInstruction( case InductionDescriptor::IK_NoInduction: llvm_unreachable("Unknown induction"); case InductionDescriptor::IK_IntInduction: - return widenIntInduction(P, Entry); - case InductionDescriptor::IK_PtrInduction: + return widenIntInduction(P); + case InductionDescriptor::IK_PtrInduction: { // Handle the pointer induction variable case. assert(P->getType()->isPointerTy() && "Unexpected type."); // This is the normalized GEP that starts counting at zero. Value *PtrInd = Induction; PtrInd = Builder.CreateSExtOrTrunc(PtrInd, II.getStep()->getType()); - // This is the vector of results. Notice that we don't generate - // vector geps because scalar geps result in better code. - for (unsigned part = 0; part < UF; ++part) { - if (VF == 1) { - int EltIndex = part; - Constant *Idx = ConstantInt::get(PtrInd->getType(), EltIndex); - Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); - Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); - SclrGep->setName("next.gep"); - Entry[part] = SclrGep; - continue; - } - - Value *VecVal = UndefValue::get(VectorType::get(P->getType(), VF)); - for (unsigned int i = 0; i < VF; ++i) { - int EltIndex = i + part * VF; - Constant *Idx = ConstantInt::get(PtrInd->getType(), EltIndex); + // Determine the number of scalars we need to generate for each unroll + // iteration. If the instruction is uniform, we only need to generate the + // first lane. Otherwise, we generate all VF values. + unsigned Lanes = Legal->isUniformAfterVectorization(P) ? 1 : VF; + // These are the scalar results. Notice that we don't generate vector GEPs + // because scalar GEPs result in better code. 
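Editor's note: the if-conversion path earlier in widenPHIInstruction folds a phi's incoming values into a chain of selects keyed by edge masks. A string-based illustration of the folding order (the first incoming value acts as the identity for single-edge phis; at least one incoming value is assumed):

#include <string>
#include <vector>

std::string blendIncoming(const std::vector<std::string> &EdgeMasks,
                          const std::vector<std::string> &Incoming) {
  std::string Blend = Incoming[0]; // identity select for single-edge phis
  for (unsigned In = 1; In < Incoming.size(); ++In)
    Blend = "select(" + EdgeMasks[In] + ", " + Incoming[In] + ", " + Blend + ")";
  return Blend;
}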
+ ScalarParts Entry(UF); + for (unsigned Part = 0; Part < UF; ++Part) { + Entry[Part].resize(VF); + for (unsigned Lane = 0; Lane < Lanes; ++Lane) { + Constant *Idx = ConstantInt::get(PtrInd->getType(), Lane + Part * VF); Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); SclrGep->setName("next.gep"); - VecVal = Builder.CreateInsertElement(VecVal, SclrGep, - Builder.getInt32(i), "insert.gep"); + Entry[Part][Lane] = SclrGep; } - Entry[part] = VecVal; } + VectorLoopValueMap.initScalar(P, Entry); return; } + case InductionDescriptor::IK_FpInduction: { + assert(P->getType() == II.getStartValue()->getType() && + "Types must match"); + // Handle other induction variables that are now based on the + // canonical one. + assert(P != OldInduction && "Primary induction can be integer only"); + + Value *V = Builder.CreateCast(Instruction::SIToFP, Induction, P->getType()); + V = II.transform(Builder, V, PSE.getSE(), DL); + V->setName("fp.offset.idx"); + + // Now we have scalar op: %fp.offset.idx = StartVal +/- Induction*StepVal + + Value *Broadcasted = getBroadcastInstrs(V); + // After broadcasting the induction variable we need to make the vector + // consecutive by adding StepVal*0, StepVal*1, StepVal*2, etc. + Value *StepVal = cast<SCEVUnknown>(II.getStep())->getValue(); + VectorParts Entry(UF); + for (unsigned part = 0; part < UF; ++part) + Entry[part] = getStepVector(Broadcasted, VF * part, StepVal, + II.getInductionOpcode()); + VectorLoopValueMap.initVector(P, Entry); + return; + } + } +} + +/// A helper function for checking whether an integer division-related +/// instruction may divide by zero (in which case it must be predicated if +/// executed conditionally in the scalar code). +/// TODO: It may be worthwhile to generalize and check isKnownNonZero(). +/// Non-zero divisors that are non compile-time constants will not be +/// converted into multiplication, so we will still end up scalarizing +/// the division, but can do so w/o predication. +static bool mayDivideByZero(Instruction &I) { + assert((I.getOpcode() == Instruction::UDiv || + I.getOpcode() == Instruction::SDiv || + I.getOpcode() == Instruction::URem || + I.getOpcode() == Instruction::SRem) && + "Unexpected instruction"); + Value *Divisor = I.getOperand(1); + auto *CInt = dyn_cast<ConstantInt>(Divisor); + return !CInt || CInt->isZero(); } void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { // For each instruction in the old loop. for (Instruction &I : *BB) { - VectorParts &Entry = WidenMap.get(&I); + + // If the instruction will become trivially dead when vectorized, we don't + // need to generate it. + if (DeadInstructions.count(&I)) + continue; + + // Scalarize instructions that should remain scalar after vectorization. + if (VF > 1 && + !(isa<BranchInst>(&I) || isa<PHINode>(&I) || + isa<DbgInfoIntrinsic>(&I)) && + shouldScalarizeInstruction(&I)) { + scalarizeInstruction(&I, Legal->isScalarWithPredication(&I)); + continue; + } switch (I.getOpcode()) { case Instruction::Br: @@ -4147,21 +4728,27 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { continue; case Instruction::PHI: { // Vectorize PHINodes. - widenPHIInstruction(&I, Entry, UF, VF, PV); + widenPHIInstruction(&I, UF, VF, PV); continue; } // End of PHI. 
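Editor's note: in the new IK_FpInduction case above, the scalar fp.offset.idx is broadcast and then made consecutive by adding a per-lane multiple of the step. A numeric model for an fadd-style induction (getStepVector handles the general induction opcode in the real code; names here are illustrative):

#include <vector>

std::vector<std::vector<double>> fpInductionParts(double FpOffsetIdx,
                                                  double Step, unsigned UF,
                                                  unsigned VF) {
  std::vector<std::vector<double>> Parts(UF, std::vector<double>(VF));
  for (unsigned Part = 0; Part < UF; ++Part)
    for (unsigned Lane = 0; Lane < VF; ++Lane)
      // Lane value = broadcast scalar + (global lane index) * step.
      Parts[Part][Lane] = FpOffsetIdx + (Part * VF + Lane) * Step;
  return Parts;
}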
+ case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::SRem: + case Instruction::URem: + // Scalarize with predication if this instruction may divide by zero and + // block execution is conditional, otherwise fallthrough. + if (Legal->isScalarWithPredication(&I)) { + scalarizeInstruction(&I, true); + continue; + } case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: case Instruction::FSub: case Instruction::Mul: case Instruction::FMul: - case Instruction::UDiv: - case Instruction::SDiv: case Instruction::FDiv: - case Instruction::URem: - case Instruction::SRem: case Instruction::FRem: case Instruction::Shl: case Instruction::LShr: @@ -4172,10 +4759,11 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { // Just widen binops. auto *BinOp = cast<BinaryOperator>(&I); setDebugLocFromInst(Builder, BinOp); - VectorParts &A = getVectorValue(BinOp->getOperand(0)); - VectorParts &B = getVectorValue(BinOp->getOperand(1)); + const VectorParts &A = getVectorValue(BinOp->getOperand(0)); + const VectorParts &B = getVectorValue(BinOp->getOperand(1)); // Use this vector value for all users of the original instruction. + VectorParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { Value *V = Builder.CreateBinOp(BinOp->getOpcode(), A[Part], B[Part]); @@ -4185,6 +4773,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { Entry[Part] = V; } + VectorLoopValueMap.initVector(&I, Entry); addMetadata(Entry, BinOp); break; } @@ -4201,20 +4790,19 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { // loop. This means that we can't just use the original 'cond' value. // We have to take the 'vectorized' value and pick the first lane. // Instcombine will make this a no-op. - VectorParts &Cond = getVectorValue(I.getOperand(0)); - VectorParts &Op0 = getVectorValue(I.getOperand(1)); - VectorParts &Op1 = getVectorValue(I.getOperand(2)); + const VectorParts &Cond = getVectorValue(I.getOperand(0)); + const VectorParts &Op0 = getVectorValue(I.getOperand(1)); + const VectorParts &Op1 = getVectorValue(I.getOperand(2)); - Value *ScalarCond = - (VF == 1) - ? Cond[0] - : Builder.CreateExtractElement(Cond[0], Builder.getInt32(0)); + auto *ScalarCond = getScalarValue(I.getOperand(0), 0, 0); + VectorParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { Entry[Part] = Builder.CreateSelect( InvariantCond ? 
ScalarCond : Cond[Part], Op0[Part], Op1[Part]); } + VectorLoopValueMap.initVector(&I, Entry); addMetadata(Entry, &I); break; } @@ -4225,8 +4813,9 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { bool FCmp = (I.getOpcode() == Instruction::FCmp); auto *Cmp = dyn_cast<CmpInst>(&I); setDebugLocFromInst(Builder, Cmp); - VectorParts &A = getVectorValue(Cmp->getOperand(0)); - VectorParts &B = getVectorValue(Cmp->getOperand(1)); + const VectorParts &A = getVectorValue(Cmp->getOperand(0)); + const VectorParts &B = getVectorValue(Cmp->getOperand(1)); + VectorParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { Value *C = nullptr; if (FCmp) { @@ -4238,6 +4827,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { Entry[Part] = C; } + VectorLoopValueMap.initVector(&I, Entry); addMetadata(Entry, &I); break; } @@ -4268,8 +4858,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { auto ID = Legal->getInductionVars()->lookup(OldInduction); if (isa<TruncInst>(CI) && CI->getOperand(0) == OldInduction && ID.getConstIntStepValue()) { - widenIntInduction(OldInduction, Entry, cast<TruncInst>(CI)); - addMetadata(Entry, &I); + widenIntInduction(OldInduction, cast<TruncInst>(CI)); break; } @@ -4277,9 +4866,11 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { Type *DestTy = (VF == 1) ? CI->getType() : VectorType::get(CI->getType(), VF); - VectorParts &A = getVectorValue(CI->getOperand(0)); + const VectorParts &A = getVectorValue(CI->getOperand(0)); + VectorParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) Entry[Part] = Builder.CreateCast(CI->getOpcode(), A[Part], DestTy); + VectorLoopValueMap.initVector(&I, Entry); addMetadata(Entry, &I); break; } @@ -4318,6 +4909,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { break; } + VectorParts Entry(UF); for (unsigned Part = 0; Part < UF; ++Part) { SmallVector<Value *, 4> Args; for (unsigned i = 0, ie = CI->getNumArgOperands(); i != ie; ++i) { @@ -4325,7 +4917,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { // Some intrinsics have a scalar argument - don't replace it with a // vector. if (!UseVectorIntrinsic || !hasVectorInstrinsicScalarOpd(ID, i)) { - VectorParts &VectorArg = getVectorValue(CI->getArgOperand(i)); + const VectorParts &VectorArg = getVectorValue(CI->getArgOperand(i)); Arg = VectorArg[Part]; } Args.push_back(Arg); @@ -4363,6 +4955,7 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) { Entry[Part] = V; } + VectorLoopValueMap.initVector(&I, Entry); addMetadata(Entry, &I); break; } @@ -4414,7 +5007,8 @@ static bool canIfConvertPHINodes(BasicBlock *BB) { bool LoopVectorizationLegality::canVectorizeWithIfConvert() { if (!EnableIfConversion) { - emitAnalysis(VectorizationReport() << "if-conversion is disabled"); + ORE->emit(createMissedAnalysis("IfConversionDisabled") + << "if-conversion is disabled"); return false; } @@ -4428,12 +5022,9 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { if (blockNeedsPredication(BB)) continue; - for (Instruction &I : *BB) { - if (auto *LI = dyn_cast<LoadInst>(&I)) - SafePointes.insert(LI->getPointerOperand()); - else if (auto *SI = dyn_cast<StoreInst>(&I)) - SafePointes.insert(SI->getPointerOperand()); - } + for (Instruction &I : *BB) + if (auto *Ptr = getPointerOperand(&I)) + SafePointes.insert(Ptr); } // Collect the blocks that need predication. 
@@ -4441,21 +5032,21 @@ bool LoopVectorizationLegality::canVectorizeWithIfConvert() { for (BasicBlock *BB : TheLoop->blocks()) { // We don't support switch statements inside loops. if (!isa<BranchInst>(BB->getTerminator())) { - emitAnalysis(VectorizationReport(BB->getTerminator()) - << "loop contains a switch statement"); + ORE->emit(createMissedAnalysis("LoopContainsSwitch", BB->getTerminator()) + << "loop contains a switch statement"); return false; } // We must be able to predicate all blocks that need to be predicated. if (blockNeedsPredication(BB)) { if (!blockCanBePredicated(BB, SafePointes)) { - emitAnalysis(VectorizationReport(BB->getTerminator()) - << "control flow cannot be substituted for a select"); + ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator()) + << "control flow cannot be substituted for a select"); return false; } } else if (BB != Header && !canIfConvertPHINodes(BB)) { - emitAnalysis(VectorizationReport(BB->getTerminator()) - << "control flow cannot be substituted for a select"); + ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator()) + << "control flow cannot be substituted for a select"); return false; } } @@ -4468,8 +5059,8 @@ bool LoopVectorizationLegality::canVectorize() { // We must have a loop in canonical form. Loops with indirectbr in them cannot // be canonicalized. if (!TheLoop->getLoopPreheader()) { - emitAnalysis(VectorizationReport() - << "loop control flow is not understood by vectorizer"); + ORE->emit(createMissedAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by vectorizer"); return false; } @@ -4478,21 +5069,22 @@ bool LoopVectorizationLegality::canVectorize() { // // We can only vectorize innermost loops. if (!TheLoop->empty()) { - emitAnalysis(VectorizationReport() << "loop is not the innermost loop"); + ORE->emit(createMissedAnalysis("NotInnermostLoop") + << "loop is not the innermost loop"); return false; } // We must have a single backedge. if (TheLoop->getNumBackEdges() != 1) { - emitAnalysis(VectorizationReport() - << "loop control flow is not understood by vectorizer"); + ORE->emit(createMissedAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by vectorizer"); return false; } // We must have a single exiting block. if (!TheLoop->getExitingBlock()) { - emitAnalysis(VectorizationReport() - << "loop control flow is not understood by vectorizer"); + ORE->emit(createMissedAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by vectorizer"); return false; } @@ -4500,8 +5092,8 @@ bool LoopVectorizationLegality::canVectorize() { // checked at the end of each iteration. With that we can assume that all // instructions in the loop are executed the same number of times. if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { - emitAnalysis(VectorizationReport() - << "loop control flow is not understood by vectorizer"); + ORE->emit(createMissedAnalysis("CFGNotUnderstood") + << "loop control flow is not understood by vectorizer"); return false; } @@ -4519,8 +5111,8 @@ bool LoopVectorizationLegality::canVectorize() { // ScalarEvolution needs to be able to find the exit count. 
const SCEV *ExitCount = PSE.getBackedgeTakenCount(); if (ExitCount == PSE.getSE()->getCouldNotCompute()) { - emitAnalysis(VectorizationReport() - << "could not determine number of loop iterations"); + ORE->emit(createMissedAnalysis("CantComputeNumberOfIterations") + << "could not determine number of loop iterations"); DEBUG(dbgs() << "LV: SCEV could not compute the loop exit count.\n"); return false; } @@ -4537,9 +5129,6 @@ bool LoopVectorizationLegality::canVectorize() { return false; } - // Collect all of the variables that remain uniform after vectorization. - collectLoopUniforms(); - DEBUG(dbgs() << "LV: We can vectorize this loop" << (LAI->getRuntimePointerChecking()->Need ? " (with a runtime bound check)" @@ -4556,14 +5145,20 @@ bool LoopVectorizationLegality::canVectorize() { if (UseInterleaved) InterleaveInfo.analyzeInterleaving(*getSymbolicStrides()); + // Collect all instructions that are known to be uniform after vectorization. + collectLoopUniforms(); + + // Collect all instructions that are known to be scalar after vectorization. + collectLoopScalars(); + unsigned SCEVThreshold = VectorizeSCEVCheckThreshold; if (Hints->getForce() == LoopVectorizeHints::FK_Enabled) SCEVThreshold = PragmaVectorizeSCEVCheckThreshold; if (PSE.getUnionPredicate().getComplexity() > SCEVThreshold) { - emitAnalysis(VectorizationReport() - << "Too many SCEV assumptions need to be made and checked " - << "at runtime"); + ORE->emit(createMissedAnalysis("TooManySCEVRunTimeChecks") + << "Too many SCEV assumptions need to be made and checked " + << "at runtime"); DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); return false; } @@ -4621,10 +5216,12 @@ void LoopVectorizationLegality::addInductionPhi( const DataLayout &DL = Phi->getModule()->getDataLayout(); // Get the widest type. - if (!WidestIndTy) - WidestIndTy = convertPointerToIntegerType(DL, PhiTy); - else - WidestIndTy = getWiderType(DL, PhiTy, WidestIndTy); + if (!PhiTy->isFloatingPointTy()) { + if (!WidestIndTy) + WidestIndTy = convertPointerToIntegerType(DL, PhiTy); + else + WidestIndTy = getWiderType(DL, PhiTy, WidestIndTy); + } // Int inductions are special because we only allow one IV. if (ID.getKind() == InductionDescriptor::IK_IntInduction && @@ -4667,8 +5264,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Check that this PHI type is allowed. if (!PhiTy->isIntegerTy() && !PhiTy->isFloatingPointTy() && !PhiTy->isPointerTy()) { - emitAnalysis(VectorizationReport(Phi) - << "loop control flow is not understood by vectorizer"); + ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi) + << "loop control flow is not understood by vectorizer"); DEBUG(dbgs() << "LV: Found an non-int non-pointer PHI.\n"); return false; } @@ -4681,16 +5278,16 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // identified reduction value with an outside user. if (!hasOutsideLoopUser(TheLoop, Phi, AllowedExit)) continue; - emitAnalysis(VectorizationReport(Phi) - << "value could not be identified as " - "an induction or reduction variable"); + ORE->emit(createMissedAnalysis("NeitherInductionNorReduction", Phi) + << "value could not be identified as " + "an induction or reduction variable"); return false; } // We only allow if-converted PHIs with exactly two incoming values. 
if (Phi->getNumIncomingValues() != 2) { - emitAnalysis(VectorizationReport(Phi) - << "control flow not understood by vectorizer"); + ORE->emit(createMissedAnalysis("CFGNotUnderstood", Phi) + << "control flow not understood by vectorizer"); DEBUG(dbgs() << "LV: Found an invalid PHI.\n"); return false; } @@ -4705,8 +5302,10 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { } InductionDescriptor ID; - if (InductionDescriptor::isInductionPHI(Phi, PSE, ID)) { + if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID)) { addInductionPhi(Phi, ID, AllowedExit); + if (ID.hasUnsafeAlgebra() && !HasFunNoNaNAttr) + Requirements->addUnsafeAlgebraInst(ID.getUnsafeAlgebraInst()); continue; } @@ -4717,14 +5316,14 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // As a last resort, coerce the PHI to a AddRec expression // and re-try classifying it a an induction PHI. - if (InductionDescriptor::isInductionPHI(Phi, PSE, ID, true)) { + if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID, true)) { addInductionPhi(Phi, ID, AllowedExit); continue; } - emitAnalysis(VectorizationReport(Phi) - << "value that could not be identified as " - "reduction is used outside the loop"); + ORE->emit(createMissedAnalysis("NonReductionValueUsedOutsideLoop", Phi) + << "value that could not be identified as " + "reduction is used outside the loop"); DEBUG(dbgs() << "LV: Found an unidentified PHI." << *Phi << "\n"); return false; } // end of PHI handling @@ -4738,8 +5337,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { !isa<DbgInfoIntrinsic>(CI) && !(CI->getCalledFunction() && TLI && TLI->isFunctionVectorizable(CI->getCalledFunction()->getName()))) { - emitAnalysis(VectorizationReport(CI) - << "call instruction cannot be vectorized"); + ORE->emit(createMissedAnalysis("CantVectorizeCall", CI) + << "call instruction cannot be vectorized"); DEBUG(dbgs() << "LV: Found a non-intrinsic, non-libfunc callsite.\n"); return false; } @@ -4750,8 +5349,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { getVectorIntrinsicIDForCall(CI, TLI), 1)) { auto *SE = PSE.getSE(); if (!SE->isLoopInvariant(PSE.getSCEV(CI->getOperand(1)), TheLoop)) { - emitAnalysis(VectorizationReport(CI) - << "intrinsic instruction cannot be vectorized"); + ORE->emit(createMissedAnalysis("CantVectorizeIntrinsic", CI) + << "intrinsic instruction cannot be vectorized"); DEBUG(dbgs() << "LV: Found unvectorizable intrinsic " << *CI << "\n"); return false; } @@ -4762,8 +5361,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if ((!VectorType::isValidElementType(I.getType()) && !I.getType()->isVoidTy()) || isa<ExtractElementInst>(I)) { - emitAnalysis(VectorizationReport(&I) - << "instruction return type cannot be vectorized"); + ORE->emit(createMissedAnalysis("CantVectorizeInstructionReturnType", &I) + << "instruction return type cannot be vectorized"); DEBUG(dbgs() << "LV: Found unvectorizable type.\n"); return false; } @@ -4772,8 +5371,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if (auto *ST = dyn_cast<StoreInst>(&I)) { Type *T = ST->getValueOperand()->getType(); if (!VectorType::isValidElementType(T)) { - emitAnalysis(VectorizationReport(ST) - << "store instruction cannot be vectorized"); + ORE->emit(createMissedAnalysis("CantVectorizeStore", ST) + << "store instruction cannot be vectorized"); return false; } @@ -4791,8 +5390,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Reduction instructions are allowed to have exit users. // All other instructions must not have external users. 
if (hasOutsideLoopUser(TheLoop, &I, AllowedExit)) { - emitAnalysis(VectorizationReport(&I) - << "value cannot be used outside the loop"); + ORE->emit(createMissedAnalysis("ValueUsedOutsideLoop", &I) + << "value cannot be used outside the loop"); return false; } @@ -4802,8 +5401,8 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { if (!Induction) { DEBUG(dbgs() << "LV: Did not find one integer induction var.\n"); if (Inductions.empty()) { - emitAnalysis(VectorizationReport() - << "loop induction variable could not be identified"); + ORE->emit(createMissedAnalysis("NoInductionVariable") + << "loop induction variable could not be identified"); return false; } } @@ -4817,12 +5416,132 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { return true; } +void LoopVectorizationLegality::collectLoopScalars() { + + // If an instruction is uniform after vectorization, it will remain scalar. + Scalars.insert(Uniforms.begin(), Uniforms.end()); + + // Collect the getelementptr instructions that will not be vectorized. A + // getelementptr instruction is only vectorized if it is used for a legal + // gather or scatter operation. + for (auto *BB : TheLoop->blocks()) + for (auto &I : *BB) { + if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { + Scalars.insert(GEP); + continue; + } + auto *Ptr = getPointerOperand(&I); + if (!Ptr) + continue; + auto *GEP = getGEPInstruction(Ptr); + if (GEP && isLegalGatherOrScatter(&I)) + Scalars.erase(GEP); + } + + // An induction variable will remain scalar if all users of the induction + // variable and induction variable update remain scalar. + auto *Latch = TheLoop->getLoopLatch(); + for (auto &Induction : *getInductionVars()) { + auto *Ind = Induction.first; + auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); + + // Determine if all users of the induction variable are scalar after + // vectorization. + auto ScalarInd = all_of(Ind->users(), [&](User *U) -> bool { + auto *I = cast<Instruction>(U); + return I == IndUpdate || !TheLoop->contains(I) || Scalars.count(I); + }); + if (!ScalarInd) + continue; + + // Determine if all users of the induction variable update instruction are + // scalar after vectorization. + auto ScalarIndUpdate = all_of(IndUpdate->users(), [&](User *U) -> bool { + auto *I = cast<Instruction>(U); + return I == Ind || !TheLoop->contains(I) || Scalars.count(I); + }); + if (!ScalarIndUpdate) + continue; + + // The induction variable and its update instruction will remain scalar. + Scalars.insert(Ind); + Scalars.insert(IndUpdate); + } +} + +bool LoopVectorizationLegality::hasConsecutiveLikePtrOperand(Instruction *I) { + if (isAccessInterleaved(I)) + return true; + if (auto *Ptr = getPointerOperand(I)) + return isConsecutivePtr(Ptr); + return false; +} + +bool LoopVectorizationLegality::isScalarWithPredication(Instruction *I) { + if (!blockNeedsPredication(I->getParent())) + return false; + switch(I->getOpcode()) { + default: + break; + case Instruction::Store: + return !isMaskRequired(I); + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::SRem: + case Instruction::URem: + return mayDivideByZero(*I); + } + return false; +} + +bool LoopVectorizationLegality::memoryInstructionMustBeScalarized( + Instruction *I, unsigned VF) { + + // If the memory instruction is in an interleaved group, it will be + // vectorized and its pointer will remain uniform. + if (isAccessInterleaved(I)) + return false; + + // Get and ensure we have a valid memory instruction. 
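Editor's note: isScalarWithPredication above is the legality hook that later drives predication of stores and of integer divisions that may fault. A condensed restatement with illustrative parameters standing in for the instruction queries:

#include <string>

bool scalarWithPredicationModel(bool BlockNeedsPredication,
                                const std::string &Opcode, bool MaskRequired,
                                bool MayDivideByZero) {
  if (!BlockNeedsPredication)
    return false;
  if (Opcode == "store")
    return !MaskRequired;               // unmasked conditional store
  if (Opcode == "udiv" || Opcode == "sdiv" || Opcode == "urem" ||
      Opcode == "srem")
    return MayDivideByZero;             // divisor not a non-zero constant
  return false;
}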
+ LoadInst *LI = dyn_cast<LoadInst>(I); + StoreInst *SI = dyn_cast<StoreInst>(I); + assert((LI || SI) && "Invalid memory instruction"); + + // If the pointer operand is uniform (loop invariant), the memory instruction + // will be scalarized. + auto *Ptr = getPointerOperand(I); + if (LI && isUniform(Ptr)) + return true; + + // If the pointer operand is non-consecutive and neither a gather nor a + // scatter operation is legal, the memory instruction will be scalarized. + if (!isConsecutivePtr(Ptr) && !isLegalGatherOrScatter(I)) + return true; + + // If the instruction is a store located in a predicated block, it will be + // scalarized. + if (isScalarWithPredication(I)) + return true; + + // If the instruction's allocated size doesn't equal it's type size, it + // requires padding and will be scalarized. + auto &DL = I->getModule()->getDataLayout(); + auto *ScalarTy = LI ? LI->getType() : SI->getValueOperand()->getType(); + if (hasIrregularType(ScalarTy, DL, VF)) + return true; + + // Otherwise, the memory instruction should be vectorized if the rest of the + // loop is. + return false; +} + void LoopVectorizationLegality::collectLoopUniforms() { // We now know that the loop is vectorizable! - // Collect variables that will remain uniform after vectorization. + // Collect instructions inside the loop that will remain uniform after + // vectorization. - // If V is not an instruction inside the current loop, it is a Value - // outside of the scope which we are interesting in. + // Global values, params and instructions outside of current loop are out of + // scope. auto isOutOfScope = [&](Value *V) -> bool { Instruction *I = dyn_cast<Instruction>(V); return (!I || !TheLoop->contains(I)); @@ -4830,30 +5549,82 @@ void LoopVectorizationLegality::collectLoopUniforms() { SetVector<Instruction *> Worklist; BasicBlock *Latch = TheLoop->getLoopLatch(); - // Start with the conditional branch. - if (!isOutOfScope(Latch->getTerminator()->getOperand(0))) { - Instruction *Cmp = cast<Instruction>(Latch->getTerminator()->getOperand(0)); + + // Start with the conditional branch. If the branch condition is an + // instruction contained in the loop that is only used by the branch, it is + // uniform. + auto *Cmp = dyn_cast<Instruction>(Latch->getTerminator()->getOperand(0)); + if (Cmp && TheLoop->contains(Cmp) && Cmp->hasOneUse()) { Worklist.insert(Cmp); DEBUG(dbgs() << "LV: Found uniform instruction: " << *Cmp << "\n"); } - // Also add all consecutive pointer values; these values will be uniform - // after vectorization (and subsequent cleanup). - for (auto *BB : TheLoop->blocks()) { + // Holds consecutive and consecutive-like pointers. Consecutive-like pointers + // are pointers that are treated like consecutive pointers during + // vectorization. The pointer operands of interleaved accesses are an + // example. + SmallSetVector<Instruction *, 8> ConsecutiveLikePtrs; + + // Holds pointer operands of instructions that are possibly non-uniform. + SmallPtrSet<Instruction *, 8> PossibleNonUniformPtrs; + + // Iterate over the instructions in the loop, and collect all + // consecutive-like pointer operands in ConsecutiveLikePtrs. If it's possible + // that a consecutive-like pointer operand will be scalarized, we collect it + // in PossibleNonUniformPtrs instead. We use two sets here because a single + // getelementptr instruction can be used by both vectorized and scalarized + // memory instructions. 
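Editor's note: memoryInstructionMustBeScalarized reduces to a handful of boolean tests. A condensed, parameterized restatement of the same checks; the function and parameter names are assumptions, not the patch's API:

bool mustScalarizeMemOpModel(bool Interleaved, bool IsLoad, bool UniformPtr,
                             bool ConsecutivePtr, bool LegalGatherOrScatter,
                             bool ScalarWithPredication, bool IrregularType) {
  if (Interleaved)
    return false;                             // handled by interleave lowering
  if (IsLoad && UniformPtr)
    return true;                              // loop-invariant address
  if (!ConsecutivePtr && !LegalGatherOrScatter)
    return true;                              // no wide or gather/scatter form
  if (ScalarWithPredication)
    return true;                              // predicated store
  if (IrregularType)
    return true;                              // padding between elements
  return false;
}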
For example, if a loop loads and stores from the same + // location, but the store is conditional, the store will be scalarized, and + // the getelementptr won't remain uniform. + for (auto *BB : TheLoop->blocks()) for (auto &I : *BB) { - if (I.getType()->isPointerTy() && isConsecutivePtr(&I)) { - Worklist.insert(&I); - DEBUG(dbgs() << "LV: Found uniform instruction: " << I << "\n"); - } + + // If there's no pointer operand, there's nothing to do. + auto *Ptr = dyn_cast_or_null<Instruction>(getPointerOperand(&I)); + if (!Ptr) + continue; + + // True if all users of Ptr are memory accesses that have Ptr as their + // pointer operand. + auto UsersAreMemAccesses = all_of(Ptr->users(), [&](User *U) -> bool { + return getPointerOperand(U) == Ptr; + }); + + // Ensure the memory instruction will not be scalarized, making its + // pointer operand non-uniform. If the pointer operand is used by some + // instruction other than a memory access, we're not going to check if + // that other instruction may be scalarized here. Thus, conservatively + // assume the pointer operand may be non-uniform. + if (!UsersAreMemAccesses || memoryInstructionMustBeScalarized(&I)) + PossibleNonUniformPtrs.insert(Ptr); + + // If the memory instruction will be vectorized and its pointer operand + // is consecutive-like, the pointer operand should remain uniform. + else if (hasConsecutiveLikePtrOperand(&I)) + ConsecutiveLikePtrs.insert(Ptr); + + // Otherwise, if the memory instruction will be vectorized and its + // pointer operand is non-consecutive-like, the memory instruction should + // be a gather or scatter operation. Its pointer operand will be + // non-uniform. + else + PossibleNonUniformPtrs.insert(Ptr); + } + + // Add to the Worklist all consecutive and consecutive-like pointers that + // aren't also identified as possibly non-uniform. + for (auto *V : ConsecutiveLikePtrs) + if (!PossibleNonUniformPtrs.count(V)) { + DEBUG(dbgs() << "LV: Found uniform instruction: " << *V << "\n"); + Worklist.insert(V); } - } // Expand Worklist in topological order: whenever a new instruction // is added , its users should be either already inside Worklist, or // out of scope. It ensures a uniform instruction will only be used // by uniform instructions or out of scope instructions. unsigned idx = 0; - do { + while (idx != Worklist.size()) { Instruction *I = Worklist[idx++]; for (auto OV : I->operand_values()) { @@ -4867,32 +5638,49 @@ void LoopVectorizationLegality::collectLoopUniforms() { DEBUG(dbgs() << "LV: Found uniform instruction: " << *OI << "\n"); } } - } while (idx != Worklist.size()); + } + + // Returns true if Ptr is the pointer operand of a memory access instruction + // I, and I is known to not require scalarization. + auto isVectorizedMemAccessUse = [&](Instruction *I, Value *Ptr) -> bool { + return getPointerOperand(I) == Ptr && !memoryInstructionMustBeScalarized(I); + }; // For an instruction to be added into Worklist above, all its users inside - // the current loop should be already added into Worklist. This condition - // cannot be true for phi instructions which is always in a dependence loop. - // Because any instruction in the dependence cycle always depends on others - // in the cycle to be added into Worklist first, the result is no ones in - // the cycle will be added into Worklist in the end. - // That is why we process PHI separately. 
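Editor's note: the pointer classification described above can be modeled with two sets — consecutive-like pointers and possibly non-uniform pointers — with the second set winning whenever a pointer lands in both. A self-contained sketch using illustrative field names:

#include <set>
#include <string>
#include <vector>

struct AccessInfo {
  std::string Ptr;
  bool AllUsersAreMemAccesses;
  bool MustBeScalarized;
  bool ConsecutiveLike;
};

std::set<std::string>
uniformPtrCandidates(const std::vector<AccessInfo> &Accesses) {
  std::set<std::string> ConsecutiveLikePtrs, PossiblyNonUniform;
  for (const AccessInfo &A : Accesses) {
    if (!A.AllUsersAreMemAccesses || A.MustBeScalarized)
      PossiblyNonUniform.insert(A.Ptr);       // may end up scalarized
    else if (A.ConsecutiveLike)
      ConsecutiveLikePtrs.insert(A.Ptr);      // uniform candidate
    else
      PossiblyNonUniform.insert(A.Ptr);       // gather/scatter operand
  }
  std::set<std::string> Candidates;
  for (const std::string &P : ConsecutiveLikePtrs)
    if (!PossiblyNonUniform.count(P))
      Candidates.insert(P);
  return Candidates;
}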
- for (auto &Induction : *getInductionVars()) { - auto *PN = Induction.first; - auto *UpdateV = PN->getIncomingValueForBlock(TheLoop->getLoopLatch()); - if (all_of(PN->users(), - [&](User *U) -> bool { - return U == UpdateV || isOutOfScope(U) || - Worklist.count(cast<Instruction>(U)); - }) && - all_of(UpdateV->users(), [&](User *U) -> bool { - return U == PN || isOutOfScope(U) || - Worklist.count(cast<Instruction>(U)); - })) { - Worklist.insert(cast<Instruction>(PN)); - Worklist.insert(cast<Instruction>(UpdateV)); - DEBUG(dbgs() << "LV: Found uniform instruction: " << *PN << "\n"); - DEBUG(dbgs() << "LV: Found uniform instruction: " << *UpdateV << "\n"); - } + // the loop should also be in Worklist. However, this condition cannot be + // true for phi nodes that form a cyclic dependence. We must process phi + // nodes separately. An induction variable will remain uniform if all users + // of the induction variable and induction variable update remain uniform. + // The code below handles both pointer and non-pointer induction variables. + for (auto &Induction : Inductions) { + auto *Ind = Induction.first; + auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); + + // Determine if all users of the induction variable are uniform after + // vectorization. + auto UniformInd = all_of(Ind->users(), [&](User *U) -> bool { + auto *I = cast<Instruction>(U); + return I == IndUpdate || !TheLoop->contains(I) || Worklist.count(I) || + isVectorizedMemAccessUse(I, Ind); + }); + if (!UniformInd) + continue; + + // Determine if all users of the induction variable update instruction are + // uniform after vectorization. + auto UniformIndUpdate = all_of(IndUpdate->users(), [&](User *U) -> bool { + auto *I = cast<Instruction>(U); + return I == Ind || !TheLoop->contains(I) || Worklist.count(I) || + isVectorizedMemAccessUse(I, IndUpdate); + }); + if (!UniformIndUpdate) + continue; + + // The induction variable and its update instruction will remain uniform. + Worklist.insert(Ind); + Worklist.insert(IndUpdate); + DEBUG(dbgs() << "LV: Found uniform instruction: " << *Ind << "\n"); + DEBUG(dbgs() << "LV: Found uniform instruction: " << *IndUpdate << "\n"); } Uniforms.insert(Worklist.begin(), Worklist.end()); @@ -4901,16 +5689,18 @@ void LoopVectorizationLegality::collectLoopUniforms() { bool LoopVectorizationLegality::canVectorizeMemory() { LAI = &(*GetLAA)(*TheLoop); InterleaveInfo.setLAI(LAI); - auto &OptionalReport = LAI->getReport(); - if (OptionalReport) - emitAnalysis(VectorizationReport(*OptionalReport)); + const OptimizationRemarkAnalysis *LAR = LAI->getReport(); + if (LAR) { + OptimizationRemarkAnalysis VR(Hints->vectorizeAnalysisPassName(), + "loop not vectorized: ", *LAR); + ORE->emit(VR); + } if (!LAI->canVectorizeMemory()) return false; if (LAI->hasStoreToLoopInvariantAddress()) { - emitAnalysis( - VectorizationReport() - << "write to a loop invariant address could not be vectorized"); + ORE->emit(createMissedAnalysis("CantVectorizeStoreToLoopInvariantAddress") + << "write to a loop invariant address could not be vectorized"); DEBUG(dbgs() << "LV: We don't allow storing to uniform addresses\n"); return false; } @@ -4967,7 +5757,6 @@ bool LoopVectorizationLegality::blockCanBePredicated( } } - // We don't predicate stores at the moment. 
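The rewritten induction handling above keeps an induction phi and its update uniform only when every in-loop user is either the other of the two, already known uniform, or a memory access that will be vectorized and merely uses the induction as its pointer. A source-level sketch (hypothetical, not from the patch) of the two situations:

  void inductions(int *a, const int *b, int n) {
    // i remains uniform: it only feeds address computation and loop control.
    for (int i = 0; i < n; ++i)
      a[i] = b[i] + 1;

    // j does not remain uniform: it is also used as the stored value, so each
    // vector lane needs its own copy of j.
    for (int j = 0; j < n; ++j)
      a[j] = j;
  }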
if (I.mayWriteToMemory()) { auto *SI = dyn_cast<StoreInst>(&I); // We only support predication of stores in basic blocks with one @@ -4992,17 +5781,6 @@ bool LoopVectorizationLegality::blockCanBePredicated( } if (I.mayThrow()) return false; - - // The instructions below can trap. - switch (I.getOpcode()) { - default: - continue; - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::URem: - case Instruction::SRem: - return false; - } } return true; @@ -5029,8 +5807,16 @@ void InterleavedAccessInfo::collectConstStrideAccesses( if (!LI && !SI) continue; - Value *Ptr = LI ? LI->getPointerOperand() : SI->getPointerOperand(); - int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides); + Value *Ptr = getPointerOperand(&I); + // We don't check wrapping here because we don't know yet if Ptr will be + // part of a full group or a group with gaps. Checking wrapping for all + // pointers (even those that end up in groups with no gaps) will be overly + // conservative. For full groups, wrapping should be ok since if we would + // wrap around the address space we would do a memory access at nullptr + // even without the transformation. The wrapping checks are therefore + // deferred until after we've formed the interleaved groups. + int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, + /*Assume=*/true, /*ShouldCheckWrap=*/false); const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType()); @@ -5234,20 +6020,66 @@ void InterleavedAccessInfo::analyzeInterleaving( if (Group->getNumMembers() != Group->getFactor()) releaseGroup(Group); - // If there is a non-reversed interleaved load group with gaps, we will need - // to execute at least one scalar epilogue iteration. This will ensure that - // we don't speculatively access memory out-of-bounds. Note that we only need - // to look for a member at index factor - 1, since every group must have a - // member at index zero. - for (InterleaveGroup *Group : LoadGroups) - if (!Group->getMember(Group->getFactor() - 1)) { + // Remove interleaved groups with gaps (currently only loads) whose memory + // accesses may wrap around. We have to revisit the getPtrStride analysis, + // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does + // not check wrapping (see documentation there). + // FORNOW we use Assume=false; + // TODO: Change to Assume=true but making sure we don't exceed the threshold + // of runtime SCEV assumptions checks (thereby potentially failing to + // vectorize altogether). + // Additional optional optimizations: + // TODO: If we are peeling the loop and we know that the first pointer doesn't + // wrap then we can deduce that all pointers in the group don't wrap. + // This means that we can forcefully peel the loop in order to only have to + // check the first pointer for no-wrap. When we'll change to use Assume=true + // we'll only need at most one runtime check per interleaved group. + // + for (InterleaveGroup *Group : LoadGroups) { + + // Case 1: A full group. Can Skip the checks; For full groups, if the wide + // load would wrap around the address space we would do a memory access at + // nullptr even without the transformation. + if (Group->getNumMembers() == Group->getFactor()) + continue; + + // Case 2: If first and last members of the group don't wrap this implies + // that all the pointers in the group don't wrap. 
+ // So we check only group member 0 (which is always guaranteed to exist), + // and group member Factor - 1; If the latter doesn't exist we rely on + // peeling (if it is a non-reveresed accsess -- see Case 3). + Value *FirstMemberPtr = getPointerOperand(Group->getMember(0)); + if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false, + /*ShouldCheckWrap=*/true)) { + DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to " + "first group member potentially pointer-wrapping.\n"); + releaseGroup(Group); + continue; + } + Instruction *LastMember = Group->getMember(Group->getFactor() - 1); + if (LastMember) { + Value *LastMemberPtr = getPointerOperand(LastMember); + if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false, + /*ShouldCheckWrap=*/true)) { + DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to " + "last group member potentially pointer-wrapping.\n"); + releaseGroup(Group); + } + } + else { + // Case 3: A non-reversed interleaved load group with gaps: We need + // to execute at least one scalar epilogue iteration. This will ensure + // we don't speculatively access memory out-of-bounds. We only need + // to look for a member at index factor - 1, since every group must have + // a member at index zero. if (Group->isReverse()) { releaseGroup(Group); - } else { - DEBUG(dbgs() << "LV: Interleaved group requires epilogue iteration.\n"); - RequiresScalarEpilogue = true; + continue; } + DEBUG(dbgs() << "LV: Interleaved group requires epilogue iteration.\n"); + RequiresScalarEpilogue = true; } + } } LoopVectorizationCostModel::VectorizationFactor @@ -5255,28 +6087,22 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { // Width 1 means no vectorize VectorizationFactor Factor = {1U, 0U}; if (OptForSize && Legal->getRuntimePointerChecking()->Need) { - emitAnalysis( - VectorizationReport() - << "runtime pointer checks needed. Enable vectorization of this " - "loop with '#pragma clang loop vectorize(enable)' when " - "compiling with -Os/-Oz"); + ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize") + << "runtime pointer checks needed. Enable vectorization of this " + "loop with '#pragma clang loop vectorize(enable)' when " + "compiling with -Os/-Oz"); DEBUG(dbgs() << "LV: Aborting. Runtime ptr check is required with -Os/-Oz.\n"); return Factor; } if (!EnableCondStoresVectorization && Legal->getNumPredStores()) { - emitAnalysis( - VectorizationReport() - << "store that is conditionally executed prevents vectorization"); + ORE->emit(createMissedAnalysis("ConditionalStore") + << "store that is conditionally executed prevents vectorization"); DEBUG(dbgs() << "LV: No vectorization. There are conditional stores.\n"); return Factor; } - // Find the trip count. - unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); - DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); - MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI); unsigned SmallestType, WidestType; std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); @@ -5334,10 +6160,13 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { // If we optimize the program for size, avoid creating the tail loop. if (OptForSize) { - // If we are unable to calculate the trip count then don't try to vectorize. + unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); + DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); + + // If we don't know the precise trip count, don't try to vectorize. 
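The Case 1/2/3 handling for load groups above boils down to: full groups need no wrap check, groups with gaps are dropped if the first (or, when present, last) member may wrap, and a non-reversed group whose last member is missing is kept but forces a scalar epilogue. A standalone sketch of just that decision logic, with the SCEV-based wrap test stubbed out by boolean flags (assumed inputs, not the patch's code):

  #include <bitset>
  #include <cstdio>

  enum class GroupAction { Keep, Release, KeepNeedsScalarEpilogue };

  // Members marks which of the Factor slots are filled; member 0 always is.
  GroupAction classifyLoadGroup(unsigned Factor, std::bitset<8> Members,
                                bool Reverse, bool FirstMayWrap,
                                bool LastMayWrap) {
    // Case 1: a full group never needs a wrap check.
    if (Members.count() == Factor)
      return GroupAction::Keep;
    // Case 2: check member 0, and member Factor - 1 if it exists.
    if (FirstMayWrap)
      return GroupAction::Release;
    if (Members[Factor - 1])
      return LastMayWrap ? GroupAction::Release : GroupAction::Keep;
    // Case 3: a gap at the end; reversed groups are dropped, other groups
    // rely on a scalar epilogue iteration to avoid out-of-bounds speculation.
    return Reverse ? GroupAction::Release
                   : GroupAction::KeepNeedsScalarEpilogue;
  }

  int main() {
    std::bitset<8> M;
    M.set(0);
    M.set(2); // a factor-4 load group with members at indices 0 and 2
    printf("%d\n", (int)classifyLoadGroup(4, M, /*Reverse=*/false,
                                          /*FirstMayWrap=*/false,
                                          /*LastMayWrap=*/false));
  }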
if (TC < 2) { - emitAnalysis( - VectorizationReport() + ORE->emit( + createMissedAnalysis("UnknownLoopCountComplexCFG") << "unable to calculate the loop count due to complex control flow"); DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); return Factor; @@ -5351,11 +6180,11 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { else { // If the trip count that we found modulo the vectorization factor is not // zero then we require a tail. - emitAnalysis(VectorizationReport() - << "cannot optimize for size and vectorize at the " - "same time. Enable vectorization of this loop " - "with '#pragma clang loop vectorize(enable)' " - "when compiling with -Os/-Oz"); + ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize") + << "cannot optimize for size and vectorize at the " + "same time. Enable vectorization of this loop " + "with '#pragma clang loop vectorize(enable)' " + "when compiling with -Os/-Oz"); DEBUG(dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); return Factor; } @@ -5367,6 +6196,7 @@ LoopVectorizationCostModel::selectVectorizationFactor(bool OptForSize) { DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); Factor.Width = UserVF; + collectInstsToScalarize(UserVF); return Factor; } @@ -5712,15 +6542,16 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { for (unsigned int i = 0; i < Index; ++i) { Instruction *I = IdxToInstr[i]; - // Ignore instructions that are never used within the loop. - if (!Ends.count(I)) - continue; // Remove all of the instructions that end at this location. InstrList &List = TransposeEnds[i]; for (Instruction *ToRemove : List) OpenIntervals.erase(ToRemove); + // Ignore instructions that are never used within the loop. + if (!Ends.count(I)) + continue; + // Skip ignored values. if (ValuesToIgnore.count(I)) continue; @@ -5772,10 +6603,160 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { return RUs; } +void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) { + + // If we aren't vectorizing the loop, or if we've already collected the + // instructions to scalarize, there's nothing to do. Collection may already + // have occurred if we have a user-selected VF and are now computing the + // expected cost for interleaving. + if (VF < 2 || InstsToScalarize.count(VF)) + return; + + // Initialize a mapping for VF in InstsToScalalarize. If we find that it's + // not profitable to scalarize any instructions, the presence of VF in the + // map will indicate that we've analyzed it already. + ScalarCostsTy &ScalarCostsVF = InstsToScalarize[VF]; + + // Find all the instructions that are scalar with predication in the loop and + // determine if it would be better to not if-convert the blocks they are in. + // If so, we also record the instructions to scalarize. 
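A few hunks up, calculateRegisterUsage was changed to close the intervals that end at the current instruction before the early continue for instructions that are never used; otherwise those intervals would be leaked and stay open. A toy sketch of that endpoint bookkeeping (plain integers stand in for instructions and values):

  #include <algorithm>
  #include <cstdio>
  #include <map>
  #include <set>
  #include <vector>

  int main() {
    // EndsAt[i] lists values whose live range ends at instruction i.
    std::map<int, std::vector<int>> EndsAt = {{2, {0}}, {4, {1, 3}}};
    std::set<int> Ignore = {3};  // e.g. values never used within the loop
    std::set<int> Open;          // currently live values
    size_t MaxUsage = 0;

    for (int i = 0; i < 6; ++i) {
      for (int V : EndsAt[i])    // close intervals first ...
        Open.erase(V);
      if (Ignore.count(i))       // ... and only then skip ignored defs
        continue;
      Open.insert(i);            // value defined by instruction i
      MaxUsage = std::max(MaxUsage, Open.size());
    }
    printf("max open intervals: %zu\n", MaxUsage);
  }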
+ for (BasicBlock *BB : TheLoop->blocks()) { + if (!Legal->blockNeedsPredication(BB)) + continue; + for (Instruction &I : *BB) + if (Legal->isScalarWithPredication(&I)) { + ScalarCostsTy ScalarCosts; + if (computePredInstDiscount(&I, ScalarCosts, VF) >= 0) + ScalarCostsVF.insert(ScalarCosts.begin(), ScalarCosts.end()); + } + } +} + +int LoopVectorizationCostModel::computePredInstDiscount( + Instruction *PredInst, DenseMap<Instruction *, unsigned> &ScalarCosts, + unsigned VF) { + + assert(!Legal->isUniformAfterVectorization(PredInst) && + "Instruction marked uniform-after-vectorization will be predicated"); + + // Initialize the discount to zero, meaning that the scalar version and the + // vector version cost the same. + int Discount = 0; + + // Holds instructions to analyze. The instructions we visit are mapped in + // ScalarCosts. Those instructions are the ones that would be scalarized if + // we find that the scalar version costs less. + SmallVector<Instruction *, 8> Worklist; + + // Returns true if the given instruction can be scalarized. + auto canBeScalarized = [&](Instruction *I) -> bool { + + // We only attempt to scalarize instructions forming a single-use chain + // from the original predicated block that would otherwise be vectorized. + // Although not strictly necessary, we give up on instructions we know will + // already be scalar to avoid traversing chains that are unlikely to be + // beneficial. + if (!I->hasOneUse() || PredInst->getParent() != I->getParent() || + Legal->isScalarAfterVectorization(I)) + return false; + + // If the instruction is scalar with predication, it will be analyzed + // separately. We ignore it within the context of PredInst. + if (Legal->isScalarWithPredication(I)) + return false; + + // If any of the instruction's operands are uniform after vectorization, + // the instruction cannot be scalarized. This prevents, for example, a + // masked load from being scalarized. + // + // We assume we will only emit a value for lane zero of an instruction + // marked uniform after vectorization, rather than VF identical values. + // Thus, if we scalarize an instruction that uses a uniform, we would + // create uses of values corresponding to the lanes we aren't emitting code + // for. This behavior can be changed by allowing getScalarValue to clone + // the lane zero values for uniforms rather than asserting. + for (Use &U : I->operands()) + if (auto *J = dyn_cast<Instruction>(U.get())) + if (Legal->isUniformAfterVectorization(J)) + return false; + + // Otherwise, we can scalarize the instruction. + return true; + }; + + // Returns true if an operand that cannot be scalarized must be extracted + // from a vector. We will account for this scalarization overhead below. Note + // that the non-void predicated instructions are placed in their own blocks, + // and their return values are inserted into vectors. Thus, an extract would + // still be required. + auto needsExtract = [&](Instruction *I) -> bool { + return TheLoop->contains(I) && !Legal->isScalarAfterVectorization(I); + }; + + // Compute the expected cost discount from scalarizing the entire expression + // feeding the predicated instruction. We currently only consider expressions + // that are single-use instruction chains. + Worklist.push_back(PredInst); + while (!Worklist.empty()) { + Instruction *I = Worklist.pop_back_val(); + + // If we've already analyzed the instruction, there's nothing to do. + if (ScalarCosts.count(I)) + continue; + + // Compute the cost of the vector instruction. 
Note that this cost already + // includes the scalarization overhead of the predicated instruction. + unsigned VectorCost = getInstructionCost(I, VF).first; + + // Compute the cost of the scalarized instruction. This cost is the cost of + // the instruction as if it wasn't if-converted and instead remained in the + // predicated block. We will scale this cost by block probability after + // computing the scalarization overhead. + unsigned ScalarCost = VF * getInstructionCost(I, 1).first; + + // Compute the scalarization overhead of needed insertelement instructions + // and phi nodes. + if (Legal->isScalarWithPredication(I) && !I->getType()->isVoidTy()) { + ScalarCost += getScalarizationOverhead(ToVectorTy(I->getType(), VF), true, + false, TTI); + ScalarCost += VF * TTI.getCFInstrCost(Instruction::PHI); + } + + // Compute the scalarization overhead of needed extractelement + // instructions. For each of the instruction's operands, if the operand can + // be scalarized, add it to the worklist; otherwise, account for the + // overhead. + for (Use &U : I->operands()) + if (auto *J = dyn_cast<Instruction>(U.get())) { + assert(VectorType::isValidElementType(J->getType()) && + "Instruction has non-scalar type"); + if (canBeScalarized(J)) + Worklist.push_back(J); + else if (needsExtract(J)) + ScalarCost += getScalarizationOverhead(ToVectorTy(J->getType(), VF), + false, true, TTI); + } + + // Scale the total scalar cost by block probability. + ScalarCost /= getReciprocalPredBlockProb(); + + // Compute the discount. A non-negative discount means the vector version + // of the instruction costs more, and scalarizing would be beneficial. + Discount += VectorCost - ScalarCost; + ScalarCosts[I] = ScalarCost; + } + + return Discount; +} + LoopVectorizationCostModel::VectorizationCostTy LoopVectorizationCostModel::expectedCost(unsigned VF) { VectorizationCostTy Cost; + // Collect the instructions (and their associated costs) that will be more + // profitable to scalarize. + collectInstsToScalarize(VF); + // For each block. for (BasicBlock *BB : TheLoop->blocks()) { VectorizationCostTy BlockCost; @@ -5802,11 +6783,14 @@ LoopVectorizationCostModel::expectedCost(unsigned VF) { << VF << " For instruction: " << I << '\n'); } - // We assume that if-converted blocks have a 50% chance of being executed. - // When the code is scalar then some of the blocks are avoided due to CF. - // When the code is vectorized we execute all code paths. + // If we are vectorizing a predicated block, it will have been + // if-converted. This means that the block's instructions (aside from + // stores and instructions that may divide by zero) will now be + // unconditionally executed. For the scalar case, we may not always execute + // the predicated block. Thus, scale the block's cost by the probability of + // executing it. if (VF == 1 && Legal->blockNeedsPredication(BB)) - BlockCost.first /= 2; + BlockCost.first /= getReciprocalPredBlockProb(); Cost.first += BlockCost.first; Cost.second |= BlockCost.second; @@ -5815,35 +6799,19 @@ LoopVectorizationCostModel::expectedCost(unsigned VF) { return Cost; } -/// \brief Check if the load/store instruction \p I may be translated into -/// gather/scatter during vectorization. -/// -/// Pointer \p Ptr specifies address in memory for the given scalar memory -/// instruction. We need it to retrieve data type. -/// Using gather/scatter is possible when it is supported by target. 
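The discount computed above is the difference between leaving the chain vectorized (with its built-in predication overhead) and scalarizing it, where the scalar cost is scaled by the chance that the predicated block executes at all. A standalone numeric sketch with made-up unit costs (the real numbers come from TTI; the 50% block probability matches the assumption used by the cost model):

  #include <cstdio>

  int main() {
    const unsigned VF = 4;
    const unsigned BlockProbRecip = 2; // assume predicated blocks run half the time

    unsigned VectorCost = 12;          // vector form, predication overhead included
    unsigned ScalarCost = VF * 2;      // VF scalar copies of the instruction
    ScalarCost += VF * 1;              // one phi per lane for the result
    ScalarCost += 4;                   // insert/extract scalarization overhead
    ScalarCost /= BlockProbRecip;      // the block may not execute every iteration

    // A non-negative discount means the vector form is the more expensive one,
    // so scalarizing the chain would pay off.
    int Discount = (int)VectorCost - (int)ScalarCost;
    printf("discount = %d\n", Discount);
  }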
-static bool isGatherOrScatterLegal(Instruction *I, Value *Ptr, - LoopVectorizationLegality *Legal) { - auto *DataTy = cast<PointerType>(Ptr->getType())->getElementType(); - return (isa<LoadInst>(I) && Legal->isLegalMaskedGather(DataTy)) || - (isa<StoreInst>(I) && Legal->isLegalMaskedScatter(DataTy)); -} - -/// \brief Check whether the address computation for a non-consecutive memory -/// access looks like an unlikely candidate for being merged into the indexing -/// mode. +/// \brief Gets Address Access SCEV after verifying that the access pattern +/// is loop invariant except the induction variable dependence. /// -/// We look for a GEP which has one index that is an induction variable and all -/// other indices are loop invariant. If the stride of this access is also -/// within a small bound we decide that this address computation can likely be -/// merged into the addressing mode. -/// In all other cases, we identify the address computation as complex. -static bool isLikelyComplexAddressComputation(Value *Ptr, - LoopVectorizationLegality *Legal, - ScalarEvolution *SE, - const Loop *TheLoop) { +/// This SCEV can be sent to the Target in order to estimate the address +/// calculation cost. +static const SCEV *getAddressAccessSCEV( + Value *Ptr, + LoopVectorizationLegality *Legal, + ScalarEvolution *SE, + const Loop *TheLoop) { auto *Gep = dyn_cast<GetElementPtrInst>(Ptr); if (!Gep) - return true; + return nullptr; // We are looking for a gep with all loop invariant indices except for one // which should be an induction variable. @@ -5852,33 +6820,11 @@ static bool isLikelyComplexAddressComputation(Value *Ptr, Value *Opd = Gep->getOperand(i); if (!SE->isLoopInvariant(SE->getSCEV(Opd), TheLoop) && !Legal->isInductionVariable(Opd)) - return true; + return nullptr; } - // Now we know we have a GEP ptr, %inv, %ind, %inv. Make sure that the step - // can likely be merged into the address computation. - unsigned MaxMergeDistance = 64; - - const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Ptr)); - if (!AddRec) - return true; - - // Check the step is constant. - const SCEV *Step = AddRec->getStepRecurrence(*SE); - // Calculate the pointer stride and check if it is consecutive. - const auto *C = dyn_cast<SCEVConstant>(Step); - if (!C) - return true; - - const APInt &APStepVal = C->getAPInt(); - - // Huge step value - give up. - if (APStepVal.getBitWidth() > 64) - return true; - - int64_t StepVal = APStepVal.getSExtValue(); - - return StepVal > MaxMergeDistance; + // Now we know we have a GEP ptr, %inv, %ind, %inv. return the Ptr SCEV. + return SE->getSCEV(Ptr); } static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) { @@ -5893,6 +6839,9 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { if (Legal->isUniformAfterVectorization(I)) VF = 1; + if (VF > 1 && isProfitableToScalarize(I, VF)) + return VectorizationCostTy(InstsToScalarize[VF][I], false); + Type *VectorTy; unsigned C = getInstructionCost(I, VF, VectorTy); @@ -5905,7 +6854,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF, Type *&VectorTy) { Type *RetTy = I->getType(); - if (VF > 1 && MinBWs.count(I)) + if (canTruncateToMinimalBitwidth(I, VF)) RetTy = IntegerType::get(RetTy->getContext(), MinBWs[I]); VectorTy = ToVectorTy(RetTy, VF); auto SE = PSE.getSE(); @@ -5932,17 +6881,42 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, // TODO: IF-converted IFs become selects. 
return 0; } + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::URem: + case Instruction::SRem: + // If we have a predicated instruction, it may not be executed for each + // vector lane. Get the scalarization cost and scale this amount by the + // probability of executing the predicated block. If the instruction is not + // predicated, we fall through to the next case. + if (VF > 1 && Legal->isScalarWithPredication(I)) { + unsigned Cost = 0; + + // These instructions have a non-void type, so account for the phi nodes + // that we will create. This cost is likely to be zero. The phi node + // cost, if any, should be scaled by the block probability because it + // models a copy at the end of each predicated block. + Cost += VF * TTI.getCFInstrCost(Instruction::PHI); + + // The cost of the non-predicated instruction. + Cost += VF * TTI.getArithmeticInstrCost(I->getOpcode(), RetTy); + + // The cost of insertelement and extractelement instructions needed for + // scalarization. + Cost += getScalarizationOverhead(I, VF, TTI); + + // Scale the cost by the probability of executing the predicated blocks. + // This assumes the predicated block for each vector lane is equally + // likely. + return Cost / getReciprocalPredBlockProb(); + } case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: case Instruction::FSub: case Instruction::Mul: case Instruction::FMul: - case Instruction::UDiv: - case Instruction::SDiv: case Instruction::FDiv: - case Instruction::URem: - case Instruction::SRem: case Instruction::FRem: case Instruction::Shl: case Instruction::LShr: @@ -5965,7 +6939,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, TargetTransformInfo::OP_None; Value *Op2 = I->getOperand(1); - // Check for a splat of a constant or for a non uniform vector of constants. + // Check for a splat or for a non uniform vector of constants. if (isa<ConstantInt>(Op2)) { ConstantInt *CInt = cast<ConstantInt>(Op2); if (CInt && CInt->getValue().isPowerOf2()) @@ -5980,10 +6954,12 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, Op2VP = TargetTransformInfo::OP_PowerOf2; Op2VK = TargetTransformInfo::OK_UniformConstantValue; } + } else if (Legal->isUniform(Op2)) { + Op2VK = TargetTransformInfo::OK_UniformValue; } - - return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, Op1VK, Op2VK, - Op1VP, Op2VP); + SmallVector<const Value *, 4> Operands(I->operand_values()); + return TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, Op1VK, + Op2VK, Op1VP, Op2VP, Operands); } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); @@ -5999,9 +6975,8 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, case Instruction::FCmp: { Type *ValTy = I->getOperand(0)->getType(); Instruction *Op0AsInstruction = dyn_cast<Instruction>(I->getOperand(0)); - auto It = MinBWs.find(Op0AsInstruction); - if (VF > 1 && It != MinBWs.end()) - ValTy = IntegerType::get(ValTy->getContext(), It->second); + if (canTruncateToMinimalBitwidth(Op0AsInstruction, VF)) + ValTy = IntegerType::get(ValTy->getContext(), MinBWs[Op0AsInstruction]); VectorTy = ToVectorTy(ValTy, VF); return TTI.getCmpSelInstrCost(I->getOpcode(), VectorTy); } @@ -6015,7 +6990,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned Alignment = SI ? SI->getAlignment() : LI->getAlignment(); unsigned AS = SI ? SI->getPointerAddressSpace() : LI->getPointerAddressSpace(); - Value *Ptr = SI ? 
SI->getPointerOperand() : LI->getPointerOperand(); + Value *Ptr = getPointerOperand(I); // We add the cost of address computation here instead of with the gep // instruction because only here we know whether the operation is // scalarized. @@ -6072,41 +7047,43 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, return Cost; } - // Scalarized loads/stores. - int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); - bool UseGatherOrScatter = - (ConsecutiveStride == 0) && isGatherOrScatterLegal(I, Ptr, Legal); - - bool Reverse = ConsecutiveStride < 0; - const DataLayout &DL = I->getModule()->getDataLayout(); - uint64_t ScalarAllocatedSize = DL.getTypeAllocSize(ValTy); - uint64_t VectorElementSize = DL.getTypeStoreSize(VectorTy) / VF; - if ((!ConsecutiveStride && !UseGatherOrScatter) || - ScalarAllocatedSize != VectorElementSize) { - bool IsComplexComputation = - isLikelyComplexAddressComputation(Ptr, Legal, SE, TheLoop); + // Check if the memory instruction will be scalarized. + if (Legal->memoryInstructionMustBeScalarized(I, VF)) { unsigned Cost = 0; - // The cost of extracting from the value vector and pointer vector. Type *PtrTy = ToVectorTy(Ptr->getType(), VF); - for (unsigned i = 0; i < VF; ++i) { - // The cost of extracting the pointer operand. - Cost += TTI.getVectorInstrCost(Instruction::ExtractElement, PtrTy, i); - // In case of STORE, the cost of ExtractElement from the vector. - // In case of LOAD, the cost of InsertElement into the returned - // vector. - Cost += TTI.getVectorInstrCost(SI ? Instruction::ExtractElement - : Instruction::InsertElement, - VectorTy, i); - } - // The cost of the scalar loads/stores. - Cost += VF * TTI.getAddressComputationCost(PtrTy, IsComplexComputation); + // Figure out whether the access is strided and get the stride value + // if it's known in compile time + const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, SE, TheLoop); + + // Get the cost of the scalar memory instruction and address computation. + Cost += VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV); Cost += VF * TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), Alignment, AS); + + // Get the overhead of the extractelement and insertelement instructions + // we might create due to scalarization. + Cost += getScalarizationOverhead(I, VF, TTI); + + // If we have a predicated store, it may not be executed for each vector + // lane. Scale the cost by the probability of executing the predicated + // block. + if (Legal->isScalarWithPredication(I)) + Cost /= getReciprocalPredBlockProb(); + return Cost; } + // Determine if the pointer operand of the access is either consecutive or + // reverse consecutive. + int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); + bool Reverse = ConsecutiveStride < 0; + + // Determine if either a gather or scatter operation is legal. + bool UseGatherOrScatter = + !ConsecutiveStride && Legal->isLegalGatherOrScatter(I); + unsigned Cost = TTI.getAddressComputationCost(VectorTy); if (UseGatherOrScatter) { assert(ConsecutiveStride == 0 && @@ -6147,7 +7124,7 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, Type *SrcScalarTy = I->getOperand(0)->getType(); Type *SrcVecTy = ToVectorTy(SrcScalarTy, VF); - if (VF > 1 && MinBWs.count(I)) { + if (canTruncateToMinimalBitwidth(I, VF)) { // This cast is going to be shrunk. This may remove the cast or it might // turn it into slightly different cast. For example, if MinBW == 16, // "zext i8 %1 to i32" becomes "zext i8 %1 to i16". 
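The predicated UDiv/SDiv/URem/SRem case and the scalarized memory path above follow the same recipe: per-lane phi and scalar-operation costs plus insert/extract overhead, divided by the reciprocal block probability when the instruction sits in a predicated block. A compact sketch of that arithmetic with assumed unit costs (TTI provides the real ones):

  #include <cstdio>

  unsigned predicatedScalarizedCost(unsigned VF, unsigned PhiCost,
                                    unsigned OpCost, unsigned InsExtOverhead,
                                    unsigned BlockProbRecip) {
    unsigned Cost = 0;
    Cost += VF * PhiCost;         // phis merging each lane's result
    Cost += VF * OpCost;          // VF scalar copies of the operation
    Cost += InsExtOverhead;       // extracts of operands, inserts of results
    return Cost / BlockProbRecip; // the predicated block may not execute
  }

  int main() {
    printf("cost = %u\n", predicatedScalarizedCost(4, 1, 5, 8, 2));
  }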
@@ -6176,28 +7153,11 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, return std::min(CallCost, getVectorIntrinsicCost(CI, VF, TTI, TLI)); return CallCost; } - default: { - // We are scalarizing the instruction. Return the cost of the scalar - // instruction, plus the cost of insert and extract into vector - // elements, times the vector width. - unsigned Cost = 0; - - if (!RetTy->isVoidTy() && VF != 1) { - unsigned InsCost = - TTI.getVectorInstrCost(Instruction::InsertElement, VectorTy); - unsigned ExtCost = - TTI.getVectorInstrCost(Instruction::ExtractElement, VectorTy); - - // The cost of inserting the results plus extracting each one of the - // operands. - Cost += VF * (InsCost + ExtCost * I->getNumOperands()); - } - + default: // The cost of executing VF copies of the scalar instruction. This opcode // is unknown. Assume that it is the same as 'mul'. - Cost += VF * TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy); - return Cost; - } + return VF * TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy) + + getScalarizationOverhead(I, VF, TTI); } // end of switch. } @@ -6217,6 +7177,7 @@ INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(LoopVectorize, LV_NAME, lv_name, false, false) namespace llvm { @@ -6226,14 +7187,11 @@ Pass *createLoopVectorizePass(bool NoUnrolling, bool AlwaysVectorize) { } bool LoopVectorizationCostModel::isConsecutiveLoadOrStore(Instruction *Inst) { - // Check for a store. - if (auto *ST = dyn_cast<StoreInst>(Inst)) - return Legal->isConsecutivePtr(ST->getPointerOperand()) != 0; - - // Check for a load. - if (auto *LI = dyn_cast<LoadInst>(Inst)) - return Legal->isConsecutivePtr(LI->getPointerOperand()) != 0; + // Check if the pointer operand of a load or store instruction is + // consecutive. + if (auto *Ptr = getPointerOperand(Inst)) + return Legal->isConsecutivePtr(Ptr); return false; } @@ -6249,123 +7207,46 @@ void LoopVectorizationCostModel::collectValuesToIgnore() { VecValuesToIgnore.insert(Casts.begin(), Casts.end()); } - // Ignore induction phis that are only used in either GetElementPtr or ICmp - // instruction to exit loop. Induction variables usually have large types and - // can have big impact when estimating register usage. - // This is for when VF > 1. - for (auto &Induction : *Legal->getInductionVars()) { - auto *PN = Induction.first; - auto *UpdateV = PN->getIncomingValueForBlock(TheLoop->getLoopLatch()); - - // Check that the PHI is only used by the induction increment (UpdateV) or - // by GEPs. Then check that UpdateV is only used by a compare instruction, - // the loop header PHI, or by GEPs. - // FIXME: Need precise def-use analysis to determine if this instruction - // variable will be vectorized. - if (all_of(PN->users(), - [&](const User *U) -> bool { - return U == UpdateV || isa<GetElementPtrInst>(U); - }) && - all_of(UpdateV->users(), [&](const User *U) -> bool { - return U == PN || isa<ICmpInst>(U) || isa<GetElementPtrInst>(U); - })) { - VecValuesToIgnore.insert(PN); - VecValuesToIgnore.insert(UpdateV); - } - } - - // Ignore instructions that will not be vectorized. - // This is for when VF > 1. 
- for (BasicBlock *BB : TheLoop->blocks()) { - for (auto &Inst : *BB) { - switch (Inst.getOpcode()) - case Instruction::GetElementPtr: { - // Ignore GEP if its last operand is an induction variable so that it is - // a consecutive load/store and won't be vectorized as scatter/gather - // pattern. - - GetElementPtrInst *Gep = cast<GetElementPtrInst>(&Inst); - unsigned NumOperands = Gep->getNumOperands(); - unsigned InductionOperand = getGEPInductionOperand(Gep); - bool GepToIgnore = true; - - // Check that all of the gep indices are uniform except for the - // induction operand. - for (unsigned i = 0; i != NumOperands; ++i) { - if (i != InductionOperand && - !PSE.getSE()->isLoopInvariant(PSE.getSCEV(Gep->getOperand(i)), - TheLoop)) { - GepToIgnore = false; - break; - } - } - - if (GepToIgnore) - VecValuesToIgnore.insert(&Inst); - break; - } - } - } + // Insert values known to be scalar into VecValuesToIgnore. This is a + // conservative estimation of the values that will later be scalarized. + // + // FIXME: Even though an instruction is not scalar-after-vectoriztion, it may + // still be scalarized. For example, we may find an instruction to be + // more profitable for a given vectorization factor if it were to be + // scalarized. But at this point, we haven't yet computed the + // vectorization factor. + for (auto *BB : TheLoop->getBlocks()) + for (auto &I : *BB) + if (Legal->isScalarAfterVectorization(&I)) + VecValuesToIgnore.insert(&I); } void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr, - bool IfPredicateStore) { + bool IfPredicateInstr) { assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); // Holds vector parameters or scalars, in case of uniform vals. SmallVector<VectorParts, 4> Params; setDebugLocFromInst(Builder, Instr); - // Find all of the vectorized parameters. - for (Value *SrcOp : Instr->operands()) { - // If we are accessing the old induction variable, use the new one. - if (SrcOp == OldInduction) { - Params.push_back(getVectorValue(SrcOp)); - continue; - } - - // Try using previously calculated values. - Instruction *SrcInst = dyn_cast<Instruction>(SrcOp); - - // If the src is an instruction that appeared earlier in the basic block - // then it should already be vectorized. - if (SrcInst && OrigLoop->contains(SrcInst)) { - assert(WidenMap.has(SrcInst) && "Source operand is unavailable"); - // The parameter is a vector value from earlier. - Params.push_back(WidenMap.get(SrcInst)); - } else { - // The parameter is a scalar from outside the loop. Maybe even a constant. - VectorParts Scalars; - Scalars.append(UF, SrcOp); - Params.push_back(Scalars); - } - } - - assert(Params.size() == Instr->getNumOperands() && - "Invalid number of operands"); - // Does this instruction return a value ? bool IsVoidRetTy = Instr->getType()->isVoidTy(); - Value *UndefVec = IsVoidRetTy ? nullptr : UndefValue::get(Instr->getType()); - // Create a new entry in the WidenMap and initialize it to Undef or Null. - VectorParts &VecResults = WidenMap.splat(Instr, UndefVec); + // Initialize a new scalar map entry. 
+ ScalarParts Entry(UF); VectorParts Cond; - if (IfPredicateStore) { - assert(Instr->getParent()->getSinglePredecessor() && - "Only support single predecessor blocks"); - Cond = createEdgeMask(Instr->getParent()->getSinglePredecessor(), - Instr->getParent()); - } + if (IfPredicateInstr) + Cond = createBlockInMask(Instr->getParent()); // For each vector unroll 'part': for (unsigned Part = 0; Part < UF; ++Part) { + Entry[Part].resize(1); // For each scalar that we create: // Start an "if (pred) a[i] = ..." block. Value *Cmp = nullptr; - if (IfPredicateStore) { + if (IfPredicateInstr) { if (Cond[Part]->getType()->isVectorTy()) Cond[Part] = Builder.CreateExtractElement(Cond[Part], Builder.getInt32(0)); @@ -6376,47 +7257,57 @@ void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr, Instruction *Cloned = Instr->clone(); if (!IsVoidRetTy) Cloned->setName(Instr->getName() + ".cloned"); - // Replace the operands of the cloned instructions with extracted scalars. + + // Replace the operands of the cloned instructions with their scalar + // equivalents in the new loop. for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { - Value *Op = Params[op][Part]; - Cloned->setOperand(op, Op); + auto *NewOp = getScalarValue(Instr->getOperand(op), Part, 0); + Cloned->setOperand(op, NewOp); } // Place the cloned scalar in the new loop. Builder.Insert(Cloned); + // Add the cloned scalar to the scalar map entry. + Entry[Part][0] = Cloned; + // If we just cloned a new assumption, add it the assumption cache. if (auto *II = dyn_cast<IntrinsicInst>(Cloned)) if (II->getIntrinsicID() == Intrinsic::assume) AC->registerAssumption(II); - // If the original scalar returns a value we need to place it in a vector - // so that future users will be able to use it. - if (!IsVoidRetTy) - VecResults[Part] = Cloned; - // End if-block. - if (IfPredicateStore) - PredicatedStores.push_back(std::make_pair(cast<StoreInst>(Cloned), Cmp)); + if (IfPredicateInstr) + PredicatedInstructions.push_back(std::make_pair(Cloned, Cmp)); } + VectorLoopValueMap.initScalar(Instr, Entry); } void InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr) { auto *SI = dyn_cast<StoreInst>(Instr); - bool IfPredicateStore = (SI && Legal->blockNeedsPredication(SI->getParent())); + bool IfPredicateInstr = (SI && Legal->blockNeedsPredication(SI->getParent())); - return scalarizeInstruction(Instr, IfPredicateStore); + return scalarizeInstruction(Instr, IfPredicateInstr); } Value *InnerLoopUnroller::reverseVector(Value *Vec) { return Vec; } Value *InnerLoopUnroller::getBroadcastInstrs(Value *V) { return V; } -Value *InnerLoopUnroller::getStepVector(Value *Val, int StartIdx, Value *Step) { +Value *InnerLoopUnroller::getStepVector(Value *Val, int StartIdx, Value *Step, + Instruction::BinaryOps BinOp) { // When unrolling and the VF is 1, we only need to add a simple scalar. - Type *ITy = Val->getType(); - assert(!ITy->isVectorTy() && "Val must be a scalar"); - Constant *C = ConstantInt::get(ITy, StartIdx); + Type *Ty = Val->getType(); + assert(!Ty->isVectorTy() && "Val must be a scalar"); + + if (Ty->isFloatingPointTy()) { + Constant *C = ConstantFP::get(Ty, (double)StartIdx); + + // Floating point operations had to be 'fast' to enable the unrolling. 
+ Value *MulOp = addFastMathFlag(Builder.CreateFMul(C, Step)); + return addFastMathFlag(Builder.CreateBinOp(BinOp, Val, MulOp)); + } + Constant *C = ConstantInt::get(Ty, StartIdx); return Builder.CreateAdd(Val, Builder.CreateMul(C, Step), "induction"); } @@ -6465,7 +7356,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { << L->getHeader()->getParent()->getName() << "\" from " << DebugLocStr << "\n"); - LoopVectorizeHints Hints(L, DisableUnrolling); + LoopVectorizeHints Hints(L, DisableUnrolling, *ORE); DEBUG(dbgs() << "LV: Loop hints:" << " force=" @@ -6483,8 +7374,8 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Looking at the diagnostic output is the only way to determine if a loop // was vectorized (other than looking at the IR or machine code), so it // is important to generate an optimization remark for each loop. Most of - // these messages are generated by emitOptimizationRemarkAnalysis. Remarks - // generated by emitOptimizationRemark and emitOptimizationRemarkMissed are + // these messages are generated as OptimizationRemarkAnalysis. Remarks + // generated as OptimizationRemark and OptimizationRemarkMissed are // less verbose reporting vectorized loops and unvectorized loops that may // benefit from vectorization, respectively. @@ -6495,17 +7386,18 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Check the loop for a trip count threshold: // do not vectorize loops with a tiny trip count. - const unsigned TC = SE->getSmallConstantTripCount(L); - if (TC > 0u && TC < TinyTripCountVectorThreshold) { + const unsigned MaxTC = SE->getSmallConstantMaxTripCount(L); + if (MaxTC > 0u && MaxTC < TinyTripCountVectorThreshold) { DEBUG(dbgs() << "LV: Found a loop with a very small trip count. " << "This loop is not worth vectorizing."); if (Hints.getForce() == LoopVectorizeHints::FK_Enabled) DEBUG(dbgs() << " But vectorizing was explicitly forced.\n"); else { DEBUG(dbgs() << "\n"); - emitAnalysisDiag(F, L, Hints, VectorizationReport() - << "vectorization is not beneficial " - "and is not explicitly forced"); + ORE->emit(createMissedAnalysis(Hints.vectorizeAnalysisPassName(), + "NotBeneficial", L) + << "vectorization is not beneficial " + "and is not explicitly forced"); return false; } } @@ -6513,17 +7405,17 @@ bool LoopVectorizePass::processLoop(Loop *L) { PredicatedScalarEvolution PSE(*SE, *L); // Check if it is legal to vectorize the loop. - LoopVectorizationRequirements Requirements; - LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, TTI, GetLAA, LI, + LoopVectorizationRequirements Requirements(*ORE); + LoopVectorizationLegality LVL(L, PSE, DT, TLI, AA, F, TTI, GetLAA, LI, ORE, &Requirements, &Hints); if (!LVL.canVectorize()) { DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); - emitMissedWarning(F, L, Hints); + emitMissedWarning(F, L, Hints, ORE); return false; } // Use the cost model. 
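For the unroller (VF == 1), getStepVector above now distinguishes integer from floating-point inductions: the integer case stays val + startIdx * step, while the FP case multiplies the step first and then applies the induction's own binary operator under fast-math. A plain-scalar sketch of the two results (doubles instead of IR values; fadd assumed as the operator):

  #include <cstdio>

  int intStep(int val, int startIdx, int step) { return val + startIdx * step; }

  double faddStep(double val, int startIdx, double step) {
    return val + (double)startIdx * step; // MulOp first, then the fadd
  }

  int main() {
    printf("%d %.1f\n", intStep(10, 3, 2), faddStep(1.5, 3, 0.5));
  }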
- LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, F, + LoopVectorizationCostModel CM(L, PSE, LI, &LVL, *TTI, TLI, DB, AC, ORE, F, &Hints); CM.collectValuesToIgnore(); @@ -6551,11 +7443,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { if (F->hasFnAttribute(Attribute::NoImplicitFloat)) { DEBUG(dbgs() << "LV: Can't vectorize when the NoImplicitFloat" "attribute is used.\n"); - emitAnalysisDiag( - F, L, Hints, - VectorizationReport() - << "loop not vectorized due to NoImplicitFloat attribute"); - emitMissedWarning(F, L, Hints); + ORE->emit(createMissedAnalysis(Hints.vectorizeAnalysisPassName(), + "NoImplicitFloat", L) + << "loop not vectorized due to NoImplicitFloat attribute"); + emitMissedWarning(F, L, Hints, ORE); return false; } @@ -6566,10 +7457,10 @@ bool LoopVectorizePass::processLoop(Loop *L) { if (Hints.isPotentiallyUnsafe() && TTI->isFPVectorizationPotentiallyUnsafe()) { DEBUG(dbgs() << "LV: Potentially unsafe FP op prevents vectorization.\n"); - emitAnalysisDiag(F, L, Hints, - VectorizationReport() - << "loop not vectorized due to unsafe FP support."); - emitMissedWarning(F, L, Hints); + ORE->emit( + createMissedAnalysis(Hints.vectorizeAnalysisPassName(), "UnsafeFP", L) + << "loop not vectorized due to unsafe FP support."); + emitMissedWarning(F, L, Hints, ORE); return false; } @@ -6584,38 +7475,43 @@ bool LoopVectorizePass::processLoop(Loop *L) { unsigned UserIC = Hints.getInterleave(); // Identify the diagnostic messages that should be produced. - std::string VecDiagMsg, IntDiagMsg; + std::pair<StringRef, std::string> VecDiagMsg, IntDiagMsg; bool VectorizeLoop = true, InterleaveLoop = true; - if (Requirements.doesNotMeet(F, L, Hints)) { DEBUG(dbgs() << "LV: Not vectorizing: loop did not meet vectorization " "requirements.\n"); - emitMissedWarning(F, L, Hints); + emitMissedWarning(F, L, Hints, ORE); return false; } if (VF.Width == 1) { DEBUG(dbgs() << "LV: Vectorization is possible but not beneficial.\n"); - VecDiagMsg = - "the cost-model indicates that vectorization is not beneficial"; + VecDiagMsg = std::make_pair( + "VectorizationNotBeneficial", + "the cost-model indicates that vectorization is not beneficial"); VectorizeLoop = false; } if (IC == 1 && UserIC <= 1) { // Tell the user interleaving is not beneficial. DEBUG(dbgs() << "LV: Interleaving is not beneficial.\n"); - IntDiagMsg = - "the cost-model indicates that interleaving is not beneficial"; + IntDiagMsg = std::make_pair( + "InterleavingNotBeneficial", + "the cost-model indicates that interleaving is not beneficial"); InterleaveLoop = false; - if (UserIC == 1) - IntDiagMsg += + if (UserIC == 1) { + IntDiagMsg.first = "InterleavingNotBeneficialAndDisabled"; + IntDiagMsg.second += " and is explicitly disabled or interleave count is set to 1"; + } } else if (IC > 1 && UserIC == 1) { // Tell the user interleaving is beneficial, but it explicitly disabled. DEBUG(dbgs() << "LV: Interleaving is beneficial but is explicitly disabled."); - IntDiagMsg = "the cost-model indicates that interleaving is beneficial " - "but is explicitly disabled or interleave count is set to 1"; + IntDiagMsg = std::make_pair( + "InterleavingBeneficialButDisabled", + "the cost-model indicates that interleaving is beneficial " + "but is explicitly disabled or interleave count is set to 1"); InterleaveLoop = false; } @@ -6626,40 +7522,48 @@ bool LoopVectorizePass::processLoop(Loop *L) { const char *VAPassName = Hints.vectorizeAnalysisPassName(); if (!VectorizeLoop && !InterleaveLoop) { // Do not vectorize or interleaving the loop. 
- emitOptimizationRemarkAnalysis(F->getContext(), VAPassName, *F, - L->getStartLoc(), VecDiagMsg); - emitOptimizationRemarkAnalysis(F->getContext(), LV_NAME, *F, - L->getStartLoc(), IntDiagMsg); + ORE->emit(OptimizationRemarkAnalysis(VAPassName, VecDiagMsg.first, + L->getStartLoc(), L->getHeader()) + << VecDiagMsg.second); + ORE->emit(OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first, + L->getStartLoc(), L->getHeader()) + << IntDiagMsg.second); return false; } else if (!VectorizeLoop && InterleaveLoop) { DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); - emitOptimizationRemarkAnalysis(F->getContext(), VAPassName, *F, - L->getStartLoc(), VecDiagMsg); + ORE->emit(OptimizationRemarkAnalysis(VAPassName, VecDiagMsg.first, + L->getStartLoc(), L->getHeader()) + << VecDiagMsg.second); } else if (VectorizeLoop && !InterleaveLoop) { DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " << DebugLocStr << '\n'); - emitOptimizationRemarkAnalysis(F->getContext(), LV_NAME, *F, - L->getStartLoc(), IntDiagMsg); + ORE->emit(OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first, + L->getStartLoc(), L->getHeader()) + << IntDiagMsg.second); } else if (VectorizeLoop && InterleaveLoop) { DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " << DebugLocStr << '\n'); DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); } + using namespace ore; if (!VectorizeLoop) { assert(IC > 1 && "interleave count should not be 1 or 0"); // If we decided that it is not legal to vectorize the loop, then // interleave it. - InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, AC, IC); - Unroller.vectorize(&LVL, CM.MinBWs, CM.VecValuesToIgnore); - - emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), - Twine("interleaved loop (interleaved count: ") + - Twine(IC) + ")"); + InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, AC, ORE, IC, &LVL, + &CM); + Unroller.vectorize(); + + ORE->emit(OptimizationRemark(LV_NAME, "Interleaved", L->getStartLoc(), + L->getHeader()) + << "interleaved loop (interleaved count: " + << NV("InterleaveCount", IC) << ")"); } else { // If we decided that it is *legal* to vectorize the loop, then do it. - InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, VF.Width, IC); - LB.vectorize(&LVL, CM.MinBWs, CM.VecValuesToIgnore); + InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width, IC, + &LVL, &CM); + LB.vectorize(); ++LoopsVectorized; // Add metadata to disable runtime unrolling a scalar loop when there are @@ -6669,10 +7573,11 @@ bool LoopVectorizePass::processLoop(Loop *L) { AddRuntimeUnrollDisableMetaData(L); // Report the vectorization decision. - emitOptimizationRemark(F->getContext(), LV_NAME, *F, L->getStartLoc(), - Twine("vectorized loop (vectorization width: ") + - Twine(VF.Width) + ", interleaved count: " + - Twine(IC) + ")"); + ORE->emit(OptimizationRemark(LV_NAME, "Vectorized", L->getStartLoc(), + L->getHeader()) + << "vectorized loop (vectorization width: " + << NV("VectorizationFactor", VF.Width) + << ", interleaved count: " << NV("InterleaveCount", IC) << ")"); } // Mark the loop as already vectorized to avoid vectorizing again. 
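All of the reporting above now goes through OptimizationRemarkEmitter: each remark names its pass, a remark identifier, the loop's debug location and header, and attaches values with ore::NV. A minimal sketch of the same pattern for a hypothetical pass name "mypass" and helper reportTransformed (assumes the LLVM headers of this import; not code from the patch):

  #include "llvm/Analysis/LoopInfo.h"
  #include "llvm/Analysis/OptimizationRemarkEmitter.h"
  #include "llvm/IR/DiagnosticInfo.h"

  using namespace llvm;

  // Hypothetical helper mirroring the ORE->emit(OptimizationRemark(...) << ...)
  // pattern used above.
  static void reportTransformed(OptimizationRemarkEmitter &ORE, Loop *L,
                                unsigned Factor) {
    ORE.emit(OptimizationRemark("mypass", "Transformed", L->getStartLoc(),
                                L->getHeader())
             << "transformed loop (factor: " << ore::NV("Factor", Factor)
             << ")");
  }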
@@ -6686,7 +7591,8 @@ bool LoopVectorizePass::runImpl( Function &F, ScalarEvolution &SE_, LoopInfo &LI_, TargetTransformInfo &TTI_, DominatorTree &DT_, BlockFrequencyInfo &BFI_, TargetLibraryInfo *TLI_, DemandedBits &DB_, AliasAnalysis &AA_, AssumptionCache &AC_, - std::function<const LoopAccessInfo &(Loop &)> &GetLAA_) { + std::function<const LoopAccessInfo &(Loop &)> &GetLAA_, + OptimizationRemarkEmitter &ORE_) { SE = &SE_; LI = &LI_; @@ -6698,6 +7604,7 @@ bool LoopVectorizePass::runImpl( AC = &AC_; GetLAA = &GetLAA_; DB = &DB_; + ORE = &ORE_; // Compute some weights outside of the loop over the loops. Compute this // using a BranchProbability to re-use its scaling math. @@ -6742,17 +7649,20 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, auto &TTI = AM.getResult<TargetIRAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &BFI = AM.getResult<BlockFrequencyAnalysis>(F); - auto *TLI = AM.getCachedResult<TargetLibraryAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &DB = AM.getResult<DemandedBitsAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - return LAM.getResult<LoopAccessAnalysis>(L); + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI}; + return LAM.getResult<LoopAccessAnalysis>(L, AR); }; - bool Changed = runImpl(F, SE, LI, TTI, DT, BFI, TLI, DB, AA, AC, GetLAA); + bool Changed = + runImpl(F, SE, LI, TTI, DT, BFI, &TLI, DB, AA, AC, GetLAA, ORE); if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; diff --git a/contrib/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/contrib/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index fb94f79..328f270 100644 --- a/contrib/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/contrib/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -115,22 +115,22 @@ static bool isValidElementType(Type *Ty) { !Ty->isPPC_FP128Ty(); } -/// \returns the parent basic block if all of the instructions in \p VL -/// are in the same block or null otherwise. -static BasicBlock *getSameBlock(ArrayRef<Value *> VL) { +/// \returns true if all of the instructions in \p VL are in the same block or +/// false otherwise. +static bool allSameBlock(ArrayRef<Value *> VL) { Instruction *I0 = dyn_cast<Instruction>(VL[0]); if (!I0) - return nullptr; + return false; BasicBlock *BB = I0->getParent(); for (int i = 1, e = VL.size(); i < e; i++) { Instruction *I = dyn_cast<Instruction>(VL[i]); if (!I) - return nullptr; + return false; if (BB != I->getParent()) - return nullptr; + return false; } - return BB; + return true; } /// \returns True if all of the values in \p VL are constants. @@ -211,12 +211,12 @@ static unsigned getSameOpcode(ArrayRef<Value *> VL) { /// of each scalar operation (VL) that will be converted into a vector (I). /// Flag set: NSW, NUW, exact, and all of fast-math. static void propagateIRFlags(Value *I, ArrayRef<Value *> VL) { - if (auto *VecOp = dyn_cast<BinaryOperator>(I)) { - if (auto *Intersection = dyn_cast<BinaryOperator>(VL[0])) { + if (auto *VecOp = dyn_cast<Instruction>(I)) { + if (auto *Intersection = dyn_cast<Instruction>(VL[0])) { // Intersection is initialized to the 0th scalar, // so start counting from index '1'. 
for (int i = 1, e = VL.size(); i < e; ++i) { - if (auto *Scalar = dyn_cast<BinaryOperator>(VL[i])) + if (auto *Scalar = dyn_cast<Instruction>(VL[i])) Intersection->andIRFlags(Scalar); } VecOp->copyIRFlags(Intersection); @@ -224,15 +224,15 @@ static void propagateIRFlags(Value *I, ArrayRef<Value *> VL) { } } -/// \returns The type that all of the values in \p VL have or null if there -/// are different types. -static Type* getSameType(ArrayRef<Value *> VL) { +/// \returns true if all of the values in \p VL have the same type or false +/// otherwise. +static bool allSameType(ArrayRef<Value *> VL) { Type *Ty = VL[0]->getType(); for (int i = 1, e = VL.size(); i < e; i++) if (VL[i]->getType() != Ty) - return nullptr; + return false; - return Ty; + return true; } /// \returns True if Extract{Value,Element} instruction extracts element Idx. @@ -393,6 +393,10 @@ public: /// \returns number of elements in vector if isomorphism exists, 0 otherwise. unsigned canMapToVector(Type *T, const DataLayout &DL) const; + /// \returns True if the VectorizableTree is both tiny and not fully + /// vectorizable. We do not vectorize such trees. + bool isTreeTinyAndNotFullyVectorizable(); + private: struct TreeEntry; @@ -883,10 +887,10 @@ private: /// List of users to ignore during scheduling and that don't need extracting. ArrayRef<Value *> UserIgnoreList; - // Number of load-bundles, which contain consecutive loads. + // Number of load bundles that contain consecutive loads. int NumLoadsWantToKeepOrder; - // Number of load-bundles of size 2, which are consecutive loads if reversed. + // Number of load bundles that contain consecutive loads in reversed order. int NumLoadsWantToChangeOrder; // Analysis and block reference. @@ -906,8 +910,11 @@ private: IRBuilder<> Builder; /// A map of scalar integer values to the smallest bit width with which they - /// can legally be represented. - MapVector<Value *, uint64_t> MinBWs; + /// can legally be represented. The values map to (width, signed) pairs, + /// where "width" indicates the minimum bit width and "signed" is True if the + /// value must be signed-extended, rather than zero-extended, back to its + /// original width. + MapVector<Value *, std::pair<uint64_t, bool>> MinBWs; }; } // end namespace llvm @@ -917,7 +924,7 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ArrayRef<Value *> UserIgnoreLst) { deleteTree(); UserIgnoreList = UserIgnoreLst; - if (!getSameType(Roots)) + if (!allSameType(Roots)) return; buildTree_rec(Roots, 0); @@ -958,8 +965,7 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, } // Ignore users in the user ignore list. - if (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), UserInst) != - UserIgnoreList.end()) + if (is_contained(UserIgnoreList, UserInst)) continue; DEBUG(dbgs() << "SLP: Need to extract:" << *U << " from lane " << @@ -972,9 +978,8 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { - bool SameTy = allConstant(VL) || getSameType(VL); (void)SameTy; bool isAltShuffle = false; - assert(SameTy && "Invalid types!"); + assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); if (Depth == RecursionMaxDepth) { DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); @@ -1007,7 +1012,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { } // If all of the operands are identical or constant we have a simple solution. 
- if (allConstant(VL) || isSplat(VL) || !getSameBlock(VL) || !Opcode) { + if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !Opcode) { DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n"); newTreeEntry(VL, false); return; @@ -1159,7 +1164,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); return; } - // Check if the loads are consecutive or of we need to swizzle them. + + // Make sure all loads in the bundle are simple - we can't vectorize + // atomic or volatile loads. for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) { LoadInst *L = cast<LoadInst>(VL[i]); if (!L->isSimple()) { @@ -1168,20 +1175,47 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth) { DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); return; } + } + // Check if the loads are consecutive, reversed, or neither. + // TODO: What we really want is to sort the loads, but for now, check + // the two likely directions. + bool Consecutive = true; + bool ReverseConsecutive = true; + for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) { if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) { - if (VL.size() == 2 && isConsecutiveAccess(VL[1], VL[0], *DL, *SE)) { - ++NumLoadsWantToChangeOrder; - } - BS.cancelScheduling(VL); - newTreeEntry(VL, false); - DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); - return; + Consecutive = false; + break; + } else { + ReverseConsecutive = false; } } - ++NumLoadsWantToKeepOrder; - newTreeEntry(VL, true); - DEBUG(dbgs() << "SLP: added a vector of loads.\n"); + + if (Consecutive) { + ++NumLoadsWantToKeepOrder; + newTreeEntry(VL, true); + DEBUG(dbgs() << "SLP: added a vector of loads.\n"); + return; + } + + // If none of the load pairs were consecutive when checked in order, + // check the reverse order. + if (ReverseConsecutive) + for (unsigned i = VL.size() - 1; i > 0; --i) + if (!isConsecutiveAccess(VL[i], VL[i - 1], *DL, *SE)) { + ReverseConsecutive = false; + break; + } + + BS.cancelScheduling(VL); + newTreeEntry(VL, false); + + if (ReverseConsecutive) { + ++NumLoadsWantToChangeOrder; + DEBUG(dbgs() << "SLP: Gathering reversed loads.\n"); + } else { + DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); + } return; } case Instruction::ZExt: @@ -1541,8 +1575,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { // If we have computed a smaller type for the expression, update VecTy so // that the costs will be accurate. if (MinBWs.count(VL[0])) - VecTy = VectorType::get(IntegerType::get(F->getContext(), MinBWs[VL[0]]), - VL.size()); + VecTy = VectorType::get( + IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size()); if (E->NeedToGather) { if (allConstant(VL)) @@ -1553,7 +1587,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { return getGatherCost(E->Scalars); } unsigned Opcode = getSameOpcode(VL); - assert(Opcode && getSameType(VL) && getSameBlock(VL) && "Invalid VL"); + assert(Opcode && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); Instruction *VL0 = cast<Instruction>(VL[0]); switch (Opcode) { case Instruction::PHI: { @@ -1762,7 +1796,10 @@ bool BoUpSLP::isFullyVectorizableTinyTree() { DEBUG(dbgs() << "SLP: Check whether the tree with height " << VectorizableTree.size() << " is fully vectorizable .\n"); - // We only handle trees of height 2. + // We only handle trees of heights 1 and 2. 
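The load-bundle hunk above replaces the old early bail-out with a two-direction check: the bundle is classified as consecutive, consecutive when reversed, or neither, and only then are the reordering counters updated (the isFullyVectorizableTinyTree hunk continues directly below). A rough standalone model of that classification, assuming integer element offsets stand in for the isConsecutiveAccess queries (classifyLoadBundle and LoadOrder are illustrative names):

    #include <cstddef>
    #include <vector>

    enum class LoadOrder { Consecutive, ReverseConsecutive, Neither };

    // Offsets[i] stands in for the element offset of the i-th load in the
    // bundle; at least two loads are assumed, as in a real bundle.
    LoadOrder classifyLoadBundle(const std::vector<long> &Offsets) {
      bool Consecutive = true;
      bool ReverseConsecutive = true;
      for (std::size_t I = 0, E = Offsets.size() - 1; I < E; ++I) {
        if (Offsets[I] + 1 != Offsets[I + 1]) {
          Consecutive = false;
          break;
        }
        ReverseConsecutive = false; // found a forward-consecutive pair
      }
      if (Consecutive)
        return LoadOrder::Consecutive;
      // Only when the forward walk never saw a consecutive pair is the
      // bundle re-checked back to front.
      if (ReverseConsecutive)
        for (std::size_t I = Offsets.size() - 1; I > 0; --I)
          if (Offsets[I] + 1 != Offsets[I - 1]) {
            ReverseConsecutive = false;
            break;
          }
      return ReverseConsecutive ? LoadOrder::ReverseConsecutive
                                : LoadOrder::Neither;
    }

As the TODO in the hunk notes, checking only these two likely directions is a stopgap for actually sorting the loads.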
+ if (VectorizableTree.size() == 1 && !VectorizableTree[0].NeedToGather) + return true; + if (VectorizableTree.size() != 2) return false; @@ -1779,6 +1816,27 @@ bool BoUpSLP::isFullyVectorizableTinyTree() { return true; } +bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() { + + // We can vectorize the tree if its size is greater than or equal to the + // minimum size specified by the MinTreeSize command line option. + if (VectorizableTree.size() >= MinTreeSize) + return false; + + // If we have a tiny tree (a tree whose size is less than MinTreeSize), we + // can vectorize it if we can prove it fully vectorizable. + if (isFullyVectorizableTinyTree()) + return false; + + assert(VectorizableTree.empty() + ? ExternalUses.empty() + : true && "We shouldn't have any external users"); + + // Otherwise, we can't vectorize the tree. It is both tiny and not fully + // vectorizable. + return true; +} + int BoUpSLP::getSpillCost() { // Walk from the bottom of the tree to the top, tracking which values are // live. When we see a call instruction that is not part of our tree, @@ -1816,9 +1874,9 @@ int BoUpSLP::getSpillCost() { ); // Now find the sequence of instructions between PrevInst and Inst. - BasicBlock::reverse_iterator InstIt(Inst->getIterator()), - PrevInstIt(PrevInst->getIterator()); - --PrevInstIt; + BasicBlock::reverse_iterator InstIt = ++Inst->getIterator().getReverse(), + PrevInstIt = + PrevInst->getIterator().getReverse(); while (InstIt != PrevInstIt) { if (PrevInstIt == PrevInst->getParent()->rend()) { PrevInstIt = Inst->getParent()->rbegin(); @@ -1846,14 +1904,6 @@ int BoUpSLP::getTreeCost() { DEBUG(dbgs() << "SLP: Calculating cost for tree of size " << VectorizableTree.size() << ".\n"); - // We only vectorize tiny trees if it is fully vectorizable. - if (VectorizableTree.size() < MinTreeSize && !isFullyVectorizableTinyTree()) { - if (VectorizableTree.empty()) { - assert(!ExternalUses.size() && "We should not have any external users"); - } - return INT_MAX; - } - unsigned BundleWidth = VectorizableTree[0].Scalars.size(); for (TreeEntry &TE : VectorizableTree) { @@ -1882,10 +1932,12 @@ int BoUpSLP::getTreeCost() { auto *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth); auto *ScalarRoot = VectorizableTree[0].Scalars[0]; if (MinBWs.count(ScalarRoot)) { - auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]); + auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first); + auto Extend = + MinBWs[ScalarRoot].second ? Instruction::SExt : Instruction::ZExt; VecTy = VectorType::get(MinTy, BundleWidth); - ExtractCost += TTI->getExtractWithExtendCost( - Instruction::SExt, EU.Scalar->getType(), VecTy, EU.Lane); + ExtractCost += TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(), + VecTy, EU.Lane); } else { ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane); @@ -2182,7 +2234,7 @@ void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL) { // Set the insertion point after the last instruction in the bundle. Set the // debug location to Front. 
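A few hunks above, MinBWs entries become (width, signed) pairs and getTreeCost now picks SExt or ZExt for the extract-with-extend cost from the second member; the setInsertPointAfterBundle hunk continues directly below. A small sketch of that lookup, with placeholder types where the pass uses Value pointers and IRBuilder casts (MinBWMap, ExtendKind, and extendForExtract are illustrative, not LLVM API):

    #include <cstdint>
    #include <map>
    #include <utility>

    // MinBWs maps a scalar root to (minimum bit width, needs sign extension);
    // the key type here is a placeholder for the pass's Value pointers.
    using MinBWMap = std::map<const void *, std::pair<uint64_t, bool>>;

    enum class ExtendKind { None, SExt, ZExt };

    // If the expression was demoted to a narrower type, an extracted lane has
    // to be widened back to the original scalar type; whether that widening is
    // a sign-extension or a zero-extension follows from the recorded flag.
    ExtendKind extendForExtract(const MinBWMap &MinBWs, const void *ScalarRoot) {
      auto It = MinBWs.find(ScalarRoot);
      if (It == MinBWs.end())
        return ExtendKind::None; // never demoted: extract at the original width
      return It->second.second ? ExtendKind::SExt : ExtendKind::ZExt;
    }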
- Builder.SetInsertPoint(BB, next(BasicBlock::iterator(LastInst))); + Builder.SetInsertPoint(BB, ++LastInst->getIterator()); Builder.SetCurrentDebugLocation(Front->getDebugLoc()); } @@ -2383,6 +2435,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { V = Builder.CreateICmp(P0, L, R); E->VectorizedValue = V; + propagateIRFlags(E->VectorizedValue, E->Scalars); ++NumVectorInstructions; return V; } @@ -2440,10 +2493,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Value *LHS = vectorizeTree(LHSVL); Value *RHS = vectorizeTree(RHSVL); - if (LHS == RHS && isa<Instruction>(LHS)) { - assert((VL0->getOperand(0) == VL0->getOperand(1)) && "Invalid order"); - } - if (Value *V = alreadyVectorized(E->Scalars)) return V; @@ -2593,6 +2642,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { ExternalUses.push_back(ExternalUser(ScalarArg, cast<User>(V), 0)); E->VectorizedValue = V; + propagateIRFlags(E->VectorizedValue, E->Scalars); ++NumVectorInstructions; return V; } @@ -2669,7 +2719,7 @@ Value *BoUpSLP::vectorizeTree() { if (auto *I = dyn_cast<Instruction>(VectorRoot)) Builder.SetInsertPoint(&*++BasicBlock::iterator(I)); auto BundleWidth = VectorizableTree[0].Scalars.size(); - auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]); + auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first); auto *VecTy = VectorType::get(MinTy, BundleWidth); auto *Trunc = Builder.CreateTrunc(VectorRoot, VecTy); VectorizableTree[0].VectorizedValue = Trunc; @@ -2677,6 +2727,16 @@ Value *BoUpSLP::vectorizeTree() { DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n"); + // If necessary, sign-extend or zero-extend ScalarRoot to the larger type + // specified by ScalarType. + auto extend = [&](Value *ScalarRoot, Value *Ex, Type *ScalarType) { + if (!MinBWs.count(ScalarRoot)) + return Ex; + if (MinBWs[ScalarRoot].second) + return Builder.CreateSExt(Ex, ScalarType); + return Builder.CreateZExt(Ex, ScalarType); + }; + // Extract all of the elements with the external uses. for (const auto &ExternalUse : ExternalUses) { Value *Scalar = ExternalUse.Scalar; @@ -2684,8 +2744,7 @@ Value *BoUpSLP::vectorizeTree() { // Skip users that we already RAUW. This happens when one instruction // has multiple uses of the same value. 
- if (std::find(Scalar->user_begin(), Scalar->user_end(), User) == - Scalar->user_end()) + if (!is_contained(Scalar->users(), User)) continue; assert(ScalarToTreeEntry.count(Scalar) && "Invalid scalar"); @@ -2712,8 +2771,7 @@ Value *BoUpSLP::vectorizeTree() { Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator()); } Value *Ex = Builder.CreateExtractElement(Vec, Lane); - if (MinBWs.count(ScalarRoot)) - Ex = Builder.CreateSExt(Ex, Scalar->getType()); + Ex = extend(ScalarRoot, Ex, Scalar->getType()); CSEBlocks.insert(PH->getIncomingBlock(i)); PH->setOperand(i, Ex); } @@ -2721,16 +2779,14 @@ Value *BoUpSLP::vectorizeTree() { } else { Builder.SetInsertPoint(cast<Instruction>(User)); Value *Ex = Builder.CreateExtractElement(Vec, Lane); - if (MinBWs.count(ScalarRoot)) - Ex = Builder.CreateSExt(Ex, Scalar->getType()); + Ex = extend(ScalarRoot, Ex, Scalar->getType()); CSEBlocks.insert(cast<Instruction>(User)->getParent()); User->replaceUsesOfWith(Scalar, Ex); } } else { Builder.SetInsertPoint(&F->getEntryBlock().front()); Value *Ex = Builder.CreateExtractElement(Vec, Lane); - if (MinBWs.count(ScalarRoot)) - Ex = Builder.CreateSExt(Ex, Scalar->getType()); + Ex = extend(ScalarRoot, Ex, Scalar->getType()); CSEBlocks.insert(&F->getEntryBlock()); User->replaceUsesOfWith(Scalar, Ex); } @@ -2759,8 +2815,7 @@ Value *BoUpSLP::vectorizeTree() { assert((ScalarToTreeEntry.count(U) || // It is legal to replace users in the ignorelist by undef. - (std::find(UserIgnoreList.begin(), UserIgnoreList.end(), U) != - UserIgnoreList.end())) && + is_contained(UserIgnoreList, U)) && "Replacing out-of-tree value with undef"); } #endif @@ -2853,7 +2908,7 @@ void BoUpSLP::optimizeGatherSequence() { } } if (In) { - assert(std::find(Visited.begin(), Visited.end(), In) == Visited.end()); + assert(!is_contained(Visited, In)); Visited.push_back(In); } } @@ -2994,9 +3049,10 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V) { } // Search up and down at the same time, because we don't know if the new // instruction is above or below the existing scheduling region. - BasicBlock::reverse_iterator UpIter(ScheduleStart->getIterator()); + BasicBlock::reverse_iterator UpIter = + ++ScheduleStart->getIterator().getReverse(); BasicBlock::reverse_iterator UpperEnd = BB->rend(); - BasicBlock::iterator DownIter(ScheduleEnd); + BasicBlock::iterator DownIter = ScheduleEnd->getIterator(); BasicBlock::iterator LowerEnd = BB->end(); for (;;) { if (++ScheduleRegionSize > ScheduleRegionSizeLimit) { @@ -3451,6 +3507,11 @@ void BoUpSLP::computeMinimumValueSizes() { Mask.getBitWidth() - Mask.countLeadingZeros(), MaxBitWidth); } + // True if the roots can be zero-extended back to their original type, rather + // than sign-extended. We know that if the leading bits are not demanded, we + // can safely zero-extend. So we initialize IsKnownPositive to True. + bool IsKnownPositive = true; + // If all the bits of the roots are demanded, we can try a little harder to // compute a narrower type. This can happen, for example, if the roots are // getelementptr indices. InstCombine promotes these indices to the pointer @@ -3462,11 +3523,41 @@ void BoUpSLP::computeMinimumValueSizes() { // compute the number of high-order bits we can truncate. if (MaxBitWidth == DL->getTypeSizeInBits(TreeRoot[0]->getType())) { MaxBitWidth = 8u; + + // Determine if the sign bit of all the roots is known to be zero. If not, + // IsKnownPositive is set to False. 
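The computeMinimumValueSizes hunk above introduces IsKnownPositive and continues below with the all_of over the tree roots. Loosely, the chosen width is the largest per-scalar requirement, plus one extra bit whenever the sign bit is not provably zero, rounded up to a power of two; the real code takes this path only when all bits of the roots are demanded and computes positivity over the roots alone. A simplified standalone sketch under those caveats (ScalarInfo and computeMinWidth are illustrative names):

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    struct ScalarInfo {
      uint64_t TypeBits;     // width of the original scalar type, in bits
      uint64_t SignBits;     // redundant sign bits, as ComputeNumSignBits reports
      bool SignBitKnownZero; // true when the value is provably non-negative
    };

    // Take the widest "significant bits" requirement over the scalars, add one
    // bit when the sign bit is unknown so sign-extension stays correct, and
    // round up to a power of two.
    std::pair<uint64_t, bool>
    computeMinWidth(const std::vector<ScalarInfo> &Scalars) {
      bool IsKnownPositive = true;
      uint64_t MaxBitWidth = 8;
      for (const ScalarInfo &S : Scalars) {
        IsKnownPositive = IsKnownPositive && S.SignBitKnownZero;
        MaxBitWidth = std::max(S.TypeBits - S.SignBits, MaxBitWidth);
      }
      if (!IsKnownPositive)
        ++MaxBitWidth; // keep room for the unknown sign bit
      uint64_t Width = 8;
      while (Width < MaxBitWidth)
        Width *= 2;
      return {Width, /*NeedsSExt=*/!IsKnownPositive};
    }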
+ IsKnownPositive = all_of(TreeRoot, [&](Value *R) { + bool KnownZero = false; + bool KnownOne = false; + ComputeSignBit(R, KnownZero, KnownOne, *DL); + return KnownZero; + }); + + // Determine the maximum number of bits required to store the scalar + // values. for (auto *Scalar : ToDemote) { auto NumSignBits = ComputeNumSignBits(Scalar, *DL, 0, AC, 0, DT); auto NumTypeBits = DL->getTypeSizeInBits(Scalar->getType()); MaxBitWidth = std::max<unsigned>(NumTypeBits - NumSignBits, MaxBitWidth); } + + // If we can't prove that the sign bit is zero, we must add one to the + // maximum bit width to account for the unknown sign bit. This preserves + // the existing sign bit so we can safely sign-extend the root back to the + // original type. Otherwise, if we know the sign bit is zero, we will + // zero-extend the root instead. + // + // FIXME: This is somewhat suboptimal, as there will be cases where adding + // one to the maximum bit width will yield a larger-than-necessary + // type. In general, we need to add an extra bit only if we can't + // prove that the upper bit of the original type is equal to the + // upper bit of the proposed smaller type. If these two bits are the + // same (either zero or one) we know that sign-extending from the + // smaller type will result in the same value. Here, since we can't + // yet prove this, we are just making the proposed smaller type + // larger to ensure correctness. + if (!IsKnownPositive) + ++MaxBitWidth; } // Round MaxBitWidth up to the next power-of-two. @@ -3486,7 +3577,7 @@ void BoUpSLP::computeMinimumValueSizes() { // Finally, map the values we can demote to the maximum bit with we computed. for (auto *Scalar : ToDemote) - MinBWs[Scalar] = MaxBitWidth; + MinBWs[Scalar] = std::make_pair(MaxBitWidth, !IsKnownPositive); } namespace { @@ -3642,8 +3733,7 @@ static bool hasValueBeenRAUWed(ArrayRef<Value *> VL, ArrayRef<WeakVH> VH, return !std::equal(VL.begin(), VL.end(), VH.begin()); } -bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, - int CostThreshold, BoUpSLP &R, +bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, unsigned VecRegSize) { unsigned ChainLen = Chain.size(); DEBUG(dbgs() << "SLP: Analyzing a store chain of length " << ChainLen @@ -3672,12 +3762,15 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, ArrayRef<Value *> Operands = Chain.slice(i, VF); R.buildTree(Operands); + if (R.isTreeTinyAndNotFullyVectorizable()) + continue; + R.computeMinimumValueSizes(); int Cost = R.getTreeCost(); DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n"); - if (Cost < CostThreshold) { + if (Cost < -SLPCostThreshold) { DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n"); R.vectorizeTree(); @@ -3691,7 +3784,7 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, } bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, - int costThreshold, BoUpSLP &R) { + BoUpSLP &R) { SetVector<StoreInst *> Heads, Tails; SmallDenseMap<StoreInst *, StoreInst *> ConsecutiveChain; @@ -3746,8 +3839,9 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, // FIXME: Is division-by-2 the correct step? Should we assert that the // register size is a power-of-2? 
- for (unsigned Size = R.getMaxVecRegSize(); Size >= R.getMinVecRegSize(); Size /= 2) { - if (vectorizeStoreChain(Operands, costThreshold, R, Size)) { + for (unsigned Size = R.getMaxVecRegSize(); Size >= R.getMinVecRegSize(); + Size /= 2) { + if (vectorizeStoreChain(Operands, R, Size)) { // Mark the vectorized stores so that we don't vectorize them again. VectorizedStores.insert(Operands.begin(), Operands.end()); Changed = true; @@ -3805,11 +3899,12 @@ bool SLPVectorizerPass::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) { bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, ArrayRef<Value *> BuildVector, - bool allowReorder) { + bool AllowReorder) { if (VL.size() < 2) return false; - DEBUG(dbgs() << "SLP: Vectorizing a list of length = " << VL.size() << ".\n"); + DEBUG(dbgs() << "SLP: Trying to vectorize a list of length = " << VL.size() + << ".\n"); // Check that all of the parts are scalar instructions of the same type. Instruction *I0 = dyn_cast<Instruction>(VL[0]); @@ -3818,10 +3913,11 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, unsigned Opcode0 = I0->getOpcode(); - // FIXME: Register size should be a parameter to this function, so we can - // try different vectorization factors. unsigned Sz = R.getVectorElementSize(I0); - unsigned VF = R.getMinVecRegSize() / Sz; + unsigned MinVF = std::max(2U, R.getMinVecRegSize() / Sz); + unsigned MaxVF = std::max<unsigned>(PowerOf2Floor(VL.size()), MinVF); + if (MaxVF < 2) + return false; for (Value *V : VL) { Type *Ty = V->getType(); @@ -3837,70 +3933,89 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, // Keep track of values that were deleted by vectorizing in the loop below. SmallVector<WeakVH, 8> TrackValues(VL.begin(), VL.end()); - for (unsigned i = 0, e = VL.size(); i < e; ++i) { - unsigned OpsWidth = 0; + unsigned NextInst = 0, MaxInst = VL.size(); + for (unsigned VF = MaxVF; NextInst + 1 < MaxInst && VF >= MinVF; + VF /= 2) { + // No actual vectorization should happen, if number of parts is the same as + // provided vectorization factor (i.e. the scalar type is used for vector + // code during codegen). + auto *VecTy = VectorType::get(VL[0]->getType(), VF); + if (TTI->getNumberOfParts(VecTy) == VF) + continue; + for (unsigned I = NextInst; I < MaxInst; ++I) { + unsigned OpsWidth = 0; - if (i + VF > e) - OpsWidth = e - i; - else - OpsWidth = VF; + if (I + VF > MaxInst) + OpsWidth = MaxInst - I; + else + OpsWidth = VF; - if (!isPowerOf2_32(OpsWidth) || OpsWidth < 2) - break; + if (!isPowerOf2_32(OpsWidth) || OpsWidth < 2) + break; - // Check that a previous iteration of this loop did not delete the Value. - if (hasValueBeenRAUWed(VL, TrackValues, i, OpsWidth)) - continue; + // Check that a previous iteration of this loop did not delete the Value. 
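The tryToVectorizeList hunk above stops hard-coding a single VF and instead searches from PowerOf2Floor(VL.size()) down to MinVF, halving each time and skipping factors that TTI would split back into scalar parts; the loop body continues below. A minimal sketch of the candidate-factor search, assuming plain integers in place of the pass's types (candidateVFs is an illustrative helper):

    #include <cstddef>
    #include <vector>

    // Candidate vectorization factors for a list of NumValues scalars: start
    // at the largest power of two not exceeding the list length and halve
    // down to MinVF (which the pass derives from the minimum vector register
    // size and which is at least 2).
    std::vector<unsigned> candidateVFs(std::size_t NumValues, unsigned MinVF) {
      unsigned MaxVF = 1;
      while (static_cast<std::size_t>(MaxVF) * 2 <= NumValues)
        MaxVF *= 2; // PowerOf2Floor(NumValues)
      if (MaxVF < MinVF)
        MaxVF = MinVF;
      std::vector<unsigned> VFs;
      for (unsigned VF = MaxVF; VF >= MinVF && VF >= 2; VF /= 2)
        VFs.push_back(VF);
      return VFs;
    }

Each candidate factor then slices the list into bundles of OpsWidth values, as in the loop that continues in the hunk below.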
+ if (hasValueBeenRAUWed(VL, TrackValues, I, OpsWidth)) + continue; - DEBUG(dbgs() << "SLP: Analyzing " << OpsWidth << " operations " - << "\n"); - ArrayRef<Value *> Ops = VL.slice(i, OpsWidth); - - ArrayRef<Value *> BuildVectorSlice; - if (!BuildVector.empty()) - BuildVectorSlice = BuildVector.slice(i, OpsWidth); - - R.buildTree(Ops, BuildVectorSlice); - // TODO: check if we can allow reordering also for other cases than - // tryToVectorizePair() - if (allowReorder && R.shouldReorder()) { - assert(Ops.size() == 2); - assert(BuildVectorSlice.empty()); - Value *ReorderedOps[] = { Ops[1], Ops[0] }; - R.buildTree(ReorderedOps, None); - } - R.computeMinimumValueSizes(); - int Cost = R.getTreeCost(); + DEBUG(dbgs() << "SLP: Analyzing " << OpsWidth << " operations " + << "\n"); + ArrayRef<Value *> Ops = VL.slice(I, OpsWidth); + + ArrayRef<Value *> BuildVectorSlice; + if (!BuildVector.empty()) + BuildVectorSlice = BuildVector.slice(I, OpsWidth); + + R.buildTree(Ops, BuildVectorSlice); + // TODO: check if we can allow reordering for more cases. + if (AllowReorder && R.shouldReorder()) { + // Conceptually, there is nothing actually preventing us from trying to + // reorder a larger list. In fact, we do exactly this when vectorizing + // reductions. However, at this point, we only expect to get here from + // tryToVectorizePair(). + assert(Ops.size() == 2); + assert(BuildVectorSlice.empty()); + Value *ReorderedOps[] = {Ops[1], Ops[0]}; + R.buildTree(ReorderedOps, None); + } + if (R.isTreeTinyAndNotFullyVectorizable()) + continue; - if (Cost < -SLPCostThreshold) { - DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n"); - Value *VectorizedRoot = R.vectorizeTree(); - - // Reconstruct the build vector by extracting the vectorized root. This - // way we handle the case where some elements of the vector are undefined. - // (return (inserelt <4 xi32> (insertelt undef (opd0) 0) (opd1) 2)) - if (!BuildVectorSlice.empty()) { - // The insert point is the last build vector instruction. The vectorized - // root will precede it. This guarantees that we get an instruction. The - // vectorized tree could have been constant folded. - Instruction *InsertAfter = cast<Instruction>(BuildVectorSlice.back()); - unsigned VecIdx = 0; - for (auto &V : BuildVectorSlice) { - IRBuilder<NoFolder> Builder(InsertAfter->getParent(), - ++BasicBlock::iterator(InsertAfter)); - Instruction *I = cast<Instruction>(V); - assert(isa<InsertElementInst>(I) || isa<InsertValueInst>(I)); - Instruction *Extract = cast<Instruction>(Builder.CreateExtractElement( - VectorizedRoot, Builder.getInt32(VecIdx++))); - I->setOperand(1, Extract); - I->removeFromParent(); - I->insertAfter(Extract); - InsertAfter = I; + R.computeMinimumValueSizes(); + int Cost = R.getTreeCost(); + + if (Cost < -SLPCostThreshold) { + DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n"); + Value *VectorizedRoot = R.vectorizeTree(); + + // Reconstruct the build vector by extracting the vectorized root. This + // way we handle the case where some elements of the vector are + // undefined. + // (return (inserelt <4 xi32> (insertelt undef (opd0) 0) (opd1) 2)) + if (!BuildVectorSlice.empty()) { + // The insert point is the last build vector instruction. The + // vectorized root will precede it. This guarantees that we get an + // instruction. The vectorized tree could have been constant folded. 
+ Instruction *InsertAfter = cast<Instruction>(BuildVectorSlice.back()); + unsigned VecIdx = 0; + for (auto &V : BuildVectorSlice) { + IRBuilder<NoFolder> Builder(InsertAfter->getParent(), + ++BasicBlock::iterator(InsertAfter)); + Instruction *I = cast<Instruction>(V); + assert(isa<InsertElementInst>(I) || isa<InsertValueInst>(I)); + Instruction *Extract = + cast<Instruction>(Builder.CreateExtractElement( + VectorizedRoot, Builder.getInt32(VecIdx++))); + I->setOperand(1, Extract); + I->removeFromParent(); + I->insertAfter(Extract); + InsertAfter = I; + } } + // Move to the next bundle. + I += VF - 1; + NextInst = I + 1; + Changed = true; } - // Move to the next bundle. - i += VF - 1; - Changed = true; } } @@ -3973,7 +4088,7 @@ static Value *createRdxShuffleMask(unsigned VecLen, unsigned NumEltsToRdx, return ConstantVector::get(ShuffleMask); } - +namespace { /// Model horizontal reductions. /// /// A horizontal reduction is a tree of reduction operations (currently add and @@ -4006,7 +4121,14 @@ class HorizontalReduction { SmallVector<Value *, 32> ReducedVals; BinaryOperator *ReductionRoot; - PHINode *ReductionPHI; + // After successfull horizontal reduction vectorization attempt for PHI node + // vectorizer tries to update root binary op by combining vectorized tree and + // the ReductionPHI node. But during vectorization this ReductionPHI can be + // vectorized itself and replaced by the undef value, while the instruction + // itself is marked for deletion. This 'marked for deletion' PHI node then can + // be used in new binary operation, causing "Use still stuck around after Def + // is destroyed" crash upon PHI node deletion. + WeakVH ReductionPHI; /// The opcode of the reduction. unsigned ReductionOpcode; @@ -4025,14 +4147,13 @@ public: unsigned MinVecRegSize; HorizontalReduction(unsigned MinVecRegSize) - : ReductionRoot(nullptr), ReductionPHI(nullptr), ReductionOpcode(0), - ReducedValueOpcode(0), IsPairwiseReduction(false), ReduxWidth(0), + : ReductionRoot(nullptr), ReductionOpcode(0), ReducedValueOpcode(0), + IsPairwiseReduction(false), ReduxWidth(0), MinVecRegSize(MinVecRegSize) {} /// \brief Try to find a reduction tree. bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) { - assert((!Phi || - std::find(Phi->op_begin(), Phi->op_end(), B) != Phi->op_end()) && + assert((!Phi || is_contained(Phi->operands(), B)) && "Thi phi needs to use the binary operator"); // We could have a initial reductions that is not an add. @@ -4113,12 +4234,21 @@ public: // Visit left or right. Value *NextV = TreeN->getOperand(EdgeToVist); - // We currently only allow BinaryOperator's and SelectInst's as reduction - // values in our tree. - if (isa<BinaryOperator>(NextV) || isa<SelectInst>(NextV)) - Stack.push_back(std::make_pair(cast<Instruction>(NextV), 0)); - else if (NextV != Phi) + if (NextV != Phi) { + auto *I = dyn_cast<Instruction>(NextV); + // Continue analysis if the next operand is a reduction operation or + // (possibly) a reduced value. If the reduced value opcode is not set, + // the first met operation != reduction operation is considered as the + // reduced value class. 
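The matchAssociativeReduction hunk above relaxes the operand walk: instead of accepting only BinaryOperator and SelectInst operands, it follows any instruction that is either another reduction operation or matches a lazily chosen reduced-value opcode, and that code begins directly below. A small standalone model of the per-operand decision (followOperand is an illustrative helper; opcode 0 stands in for a non-instruction operand):

    // An operand is followed if it is another reduction operation, or if its
    // opcode matches the reduced-value opcode; the first non-reduction opcode
    // encountered fixes that class for the rest of the walk.
    bool followOperand(unsigned OperandOpcode, unsigned ReductionOpcode,
                       unsigned &ReducedValueOpcode) {
      if (!OperandOpcode)
        return false;                             // not an instruction at all
      if (OperandOpcode == ReductionOpcode)
        return true;                              // deeper reduction node
      if (!ReducedValueOpcode) {
        ReducedValueOpcode = OperandOpcode;       // first non-reduction op wins
        return true;
      }
      return OperandOpcode == ReducedValueOpcode; // must stay in the same class
    }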
+ if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode || + I->getOpcode() == ReductionOpcode)) { + if (!ReducedValueOpcode && I->getOpcode() != ReductionOpcode) + ReducedValueOpcode = I->getOpcode(); + Stack.push_back(std::make_pair(I, 0)); + continue; + } return false; + } } return true; } @@ -4141,7 +4271,15 @@ public: unsigned i = 0; for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) { - V.buildTree(makeArrayRef(&ReducedVals[i], ReduxWidth), ReductionOps); + auto VL = makeArrayRef(&ReducedVals[i], ReduxWidth); + V.buildTree(VL, ReductionOps); + if (V.shouldReorder()) { + SmallVector<Value *, 8> Reversed(VL.rbegin(), VL.rend()); + V.buildTree(Reversed, ReductionOps); + } + if (V.isTreeTinyAndNotFullyVectorizable()) + continue; + V.computeMinimumValueSizes(); // Estimate cost. @@ -4175,7 +4313,7 @@ public: ReducedVals[i]); } // Update users. - if (ReductionPHI) { + if (ReductionPHI && !isa<UndefValue>(ReductionPHI)) { assert(ReductionRoot && "Need a reduction operation"); ReductionRoot->setOperand(0, VectorizedTree); ReductionRoot->setOperand(1, ReductionPHI); @@ -4202,7 +4340,8 @@ private: int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost; int ScalarReduxCost = - ReduxWidth * TTI->getArithmeticInstrCost(ReductionOpcode, VecTy); + (ReduxWidth - 1) * + TTI->getArithmeticInstrCost(ReductionOpcode, ScalarTy); DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost << " for reduction that starts with " << *FirstReducedVal @@ -4254,6 +4393,7 @@ private: return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); } }; +} // end anonymous namespace /// \brief Recognize construction of vectors like /// %ra = insertelement <4 x float> undef, float %s0, i32 0 @@ -4354,7 +4494,7 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P, return nullptr; // There is a loop latch, return the incoming value if it comes from - // that. This reduction pattern occassionaly turns up. + // that. This reduction pattern occasionally turns up. if (P->getIncomingBlock(0) == BBLatch) { Rdx = P->getIncomingValue(0); } else if (P->getIncomingBlock(1) == BBLatch) { @@ -4510,8 +4650,10 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(RI->getOperand(0))) { DEBUG(dbgs() << "SLP: Found a return to vectorize.\n"); - if (tryToVectorizePair(BinOp->getOperand(0), - BinOp->getOperand(1), R)) { + if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI, + R.getMinVecRegSize()) || + tryToVectorizePair(BinOp->getOperand(0), BinOp->getOperand(1), + R)) { Changed = true; it = BB->begin(); e = BB->end(); @@ -4690,8 +4832,7 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { // may cause a significant compile-time increase. for (unsigned CI = 0, CE = it->second.size(); CI < CE; CI+=16) { unsigned Len = std::min<unsigned>(CE - CI, 16); - Changed |= vectorizeStores(makeArrayRef(&it->second[CI], Len), - -SLPCostThreshold, R); + Changed |= vectorizeStores(makeArrayRef(&it->second[CI], Len), R); } } return Changed; |
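Finally, the reduction cost hunk near the end corrects the scalar baseline: reducing ReduxWidth values takes ReduxWidth - 1 scalar operations on the scalar type, not ReduxWidth operations on the vector type. A hedged sketch of the comparison, with std::min standing in for the pairwise-versus-splitting choice the pass makes through TTI (reductionCostDelta is an illustrative helper):

    #include <algorithm>

    // A horizontal reduction of ReduxWidth scalars needs only ReduxWidth - 1
    // scalar reduction operations, so that is the baseline the vector cost
    // has to beat.
    int reductionCostDelta(unsigned ReduxWidth, int ScalarOpCost,
                           int PairwiseRdxCost, int SplittingRdxCost) {
      int VecReduxCost = std::min(PairwiseRdxCost, SplittingRdxCost);
      int ScalarReduxCost = static_cast<int>(ReduxWidth - 1) * ScalarOpCost;
      return VecReduxCost - ScalarReduxCost; // negative: vectorizing wins
    }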